mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-12 08:56:28 +00:00
When speculative decoding rejects draft tokens and restores the recurrent state checkpoint, the sampler (RNG, grammar, prev tokens) must also be restored to maintain consistency. Without this, the sampler state reflects the rejected draft tokens, leading to potential divergence. Uses common_sampler_clone() to snapshot the sampler before the speculative batch decode, and restores it on rejection.
4284 lines
172 KiB
C++
4284 lines
172 KiB
C++
#include "server-context.h"
|
|
#include "server-common.h"
|
|
#include "server-task.h"
|
|
#include "server-queue.h"
|
|
|
|
#include "common.h"
|
|
#include "llama.h"
|
|
#include "log.h"
|
|
#include "sampling.h"
|
|
#include "speculative.h"
|
|
#include "mtmd.h"
|
|
#include "mtmd-helper.h"
|
|
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <regex>
|
|
|
|
static void log_text(const gpt_params & params_base, const std::string & text) {
|
|
if (params_base.minilog) {
|
|
LOG_TEE("%s\n", text.c_str());
|
|
}
|
|
}
|
|
|
|
// Tear down the server context: free the main llama context and model, the
// multimodal context, the (legacy) draft context/model, all per-slot sampler
// and speculative resources, and the shared decode batch.
server_context::~server_context() {
    if (ctx) {
        llama_free(ctx);
        ctx = nullptr;
    }

    if (model) {
        llama_free_model(model);
        model = nullptr;
    }
    // Free multimodal
    // NOTE(review): mctx is passed unconditionally — presumably mtmd_free is
    // safe on nullptr; confirm against the mtmd API.
    mtmd_free(mctx);
    // Free draft model and context if they exist
    if (ctx_draft) {
        llama_free(ctx_draft);
        ctx_draft = nullptr;
    }
    if (model_draft) {
        llama_free_model(model_draft);
        model_draft = nullptr;
    }

    // Clear any sampling context
    for (server_slot& slot : slots) {
        if (slot.ctx_sampling != nullptr) {
            common_sampler_free(slot.ctx_sampling);
        }
        // sampler snapshot taken before a speculative batch (see reset())
        if (slot.spec_ckpt_sampler != nullptr) {
            common_sampler_free(slot.spec_ckpt_sampler);
        }
        if (slot.ctx_dft) {
            llama_free(slot.ctx_dft);
        }
        common_speculative_free(slot.spec);
        llama_batch_free(slot.batch_spec);
    }

    llama_batch_free(batch);
}
|
|
|
|
// Load the main model (and, if configured, the multimodal projector and the
// speculative-decoding draft model) according to `params_`.
// Returns false on any fatal load or parse error.
bool server_context::load_model(const gpt_params& params_) {
    params_base = params_;

    llama_init_result llama_init = llama_init_from_gpt_params(params_base);

    model = llama_init.model;
    ctx = llama_init.context;
    lora_adapters = llama_init.lora_adapters;

    if (model == nullptr) {
        LOG_ERROR("unable to load model", { {"model", params_base.model} });
        return false;
    }

    n_ctx = llama_n_ctx(ctx);

    // tokenizer BOS/EOS behavior, queried once from the model
    add_bos_token = llama_should_add_bos_token(model);
    has_eos_token = llama_add_eos_token(model) != 1;

    // a draft model is requested either via an explicit path or via extra CLI params
    bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
    std::string& mmproj_path = params_base.mmproj.path;
    // ---- optional multimodal projector (mmproj) ----
    if (!mmproj_path.empty()) {
        mtmd_context_params mparams = mtmd_context_params_default();
        mparams.use_gpu = params_base.mmproj_use_gpu;
        mparams.print_timings = false;
        mparams.n_threads = params_base.n_threads;
        mparams.flash_attn_type = params_base.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
        mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
        mparams.image_min_tokens = params_base.image_min_tokens;
        mparams.image_max_tokens = params_base.image_max_tokens;
        mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
        if (mctx == nullptr) {
            LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
            return false;
        }
        LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());

        //if (params.n_cache_reuse) {
        //    params_base.n_cache_reuse = 0;
        //    SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
        //}

        // multimodal and speculative decoding are mutually exclusive:
        // an explicit draft model is a hard error, a configured speculative
        // type is silently disabled
        if (has_draft_model) {
            LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
            return false;
        }
        if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
            params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
            SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
        }
    }
    // Load draft model for speculative decoding if specified
    if (has_draft_model) {
        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");

        // draft params start from defaults, then inherit selected base settings
        gpt_params params_dft;

        params_dft.devices = params_base.speculative.devices;
        params_dft.model = params_base.speculative.model;
        params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
        params_dft.rpc_servers = params_base.rpc_servers;
        // KV cache types fall back to the base model's settings when unset
        params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
        params_dft.cache_type_v = params_base.speculative.cache_type_v.empty() ? params_base.cache_type_v : params_base.speculative.cache_type_v;
        params_dft.flash_attn = params_base.flash_attn;
        // extra CLI-style parameters applied to the draft model only
        if (!params_base.speculative.params.empty()) {
            auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params);
            if (!gpt_params_parse(argc, argv, params_dft)) {
                gpt_params_print_usage(argc, argv, params_dft);
                free_command_line(argc, argv);
                return false;
            };
            free_command_line(argc, argv);
        }
        LOG_INFO("", { {"model", params_dft.model} });
        if (params_dft.n_ctx == 0) {
            params_dft.n_ctx = params_base.speculative.n_ctx;
        }
        // default draft context: base context divided among parallel slots
        params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
        params_dft.n_parallel = 1;
        params_dft.n_batch = params_dft.n_ctx;

        params_base.speculative.mparams_dft.path = params_dft.model; //

        llama_model_params mparams_dft = common_model_params_to_llama(params_dft);

        llama_model * model_dft = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
        if (model_dft == nullptr) {
            LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
            return false;
        }

        cparams_dft = common_context_params_to_llama(params_dft);

        // hand the loaded draft model/params over to the speculative subsystem
        params_base.speculative.model_dft = model_dft;
        params_base.speculative.cparams_dft = cparams_dft;

    }
    else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
        // -mtp was requested but this model has no NextN (MTP) layers
        LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
        params_base.has_mtp = false;
    }
    return true;
}
|
|
|
|
// Initialize per-slot state (context split, self-extend, MTP, speculative
// decoding), the shared decode batch, metrics, the optional prompt cache and
// the chat-template parameters. Must be called after load_model().
void server_context::init() {
    // each parallel slot gets an equal share of the total context
    const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;

    LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });

    for (int i = 0; i < params_base.n_parallel; i++) {
        server_slot slot;

        slot.id = i;
        slot.ctx = ctx;
        slot.n_ctx = n_ctx_slot;
        slot.n_predict = params_base.n_predict;
        slot.mctx = mctx;
        slot.cache_tokens.has_mtmd = mctx != nullptr;
        // whether reasoning ("think") tokens are excluded from slot-similarity matching
        slot.params.think_tokens = params_base.think_tokens;
        if (params_base.think_tokens.exclude) {
            SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params_base.think_tokens.begin.c_str(), params_base.think_tokens.end.c_str() );
        }
        else {
            SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
        }
        LOG_INFO("new slot", {
            {"id_slot", slot.id},
            {"n_ctx_slot", slot.n_ctx}
        });

        // group-attention (self-extend) configuration
        const int ga_n = params_base.grp_attn_n;
        const int ga_w = params_base.grp_attn_w;

        if (ga_n != 1) {
            GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
            GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
            //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
            //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT

            LOG_INFO("slot self-extend", {
                {"id_slot", slot.id},
                {"ga_n", ga_n},
                {"ga_w", ga_w}
            });
        }

        slot.ga_i = 0;
        slot.ga_n = ga_n;
        slot.ga_w = ga_w;

        slot.sparams = params_base.sparams;

        // ---- MTP (multi-token prediction) setup ----
        // NOTE(review): this mutates params_base on every loop iteration;
        // presumably idempotent across slots — confirm.
        if (params_base.has_mtp) {
            if (llama_model_n_nextn_layer(model) > 0) {
                params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
                params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;

                // draft context params mirror the base context, with MTP enabled
                params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
                params_base.speculative.cparams_dft.mtp = true;
                params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
                params_base.speculative.cparams_dft.embeddings = true;

                slot.has_mtp = true;
                slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
                slot.params.speculative.n_min = 0;
                slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;

                // +1: room for the sampled token in addition to n_max drafts
                slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
                SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);

                SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
                llama_set_embeddings(ctx, true);
            }
            else {
                SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
                params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
                slot.has_mtp = false;
            }
        }

        // ---- speculative decoding setup ----
        bool can_spec = true;
        if (!params_base.dry_run) {
            can_spec = common_speculative_is_compat(ctx);
        }
        if (!can_spec) {
            SRV_WRN("%s", "speculative decoding not supported by this context\n");
        }
        // try speculative decoding
        if (can_spec) {
            slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
            if (slot.spec) {
                // NOTE(review): this early return leaves the remaining slots
                // uninitialized — presumably load_model() already rejects the
                // multimodal+speculative combination before reaching here.
                if (mctx) {
                    SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
                    return;
                }
                SLT_INF(slot, "%s", "speculative decoding context initialized\n");
            } else {
                if (slot.has_mtp) {
                    // MTP without its speculative context cannot proceed
                    SRV_ERR("%s", "failed to initialize MTP speculative context, aborting\n");
                    GGML_ABORT("MTP context creation failed");
                } else {
                    SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
                }
            }
        }

        slot.reset();

        slots.push_back(std::move(slot));
    }

    // defaults reported by the /props endpoint (derived from the first slot)
    default_generation_settings_for_props = get_formated_generation(slots.front());
    default_generation_settings_for_props["seed"] = -1;

    // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
    // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
    {
        const int32_t n_batch = llama_n_batch(ctx);

        // only a single seq_id per token is needed
        batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
    }

    metrics.init();

    // ---- optional prompt cache (RAM-limited) ----
    if (params_base.cache_ram_mib != 0) {
        if (params_base.cache_ram_mib < 0) {
            LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
        }
        else {
            LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
        }
        LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
        // only apply ram size limit. No token limit for now.
        prompt_cache = std::make_unique<server_prompt_cache>(ctx, params_base.cache_ram_mib, 0);
    }
    else {
        LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
    }

    // populate chat template params
    {
        common_chat_templates_ptr chat_templates;

        try {
            chat_templates = common_chat_templates_init(model, params_base.chat_template);

            LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
                common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());

        }
        catch (const std::exception & e) {
            // a broken template is non-fatal; the server keeps the null template
            SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
            SRV_ERR("%s: please consider enabling jinja via --jinja, or use a custom chat template via --chat-template\n", __func__);
            SRV_ERR("%s: for example: --chat-template chatml\n", __func__);
        }

        // thinking is enabled if:
        // 1. It's not explicitly disabled (reasoning_budget == 0)
        // 2. The chat template supports it
        const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
        SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);

        chat_params = {
            /* use_jinja             */ params_base.use_jinja,
            /* use_peg               */ params_base.use_peg,
            /* prefill_assistant     */ params_base.prefill_assistant,
            /* reasoning_format      */ params_base.reasoning_format,
            /* chat_template_kwargs  */ params_base.default_template_kwargs,
            /* tmpls                 */ std::move(chat_templates),
            /* allow_image           */ mctx ? mtmd_support_vision(mctx) : false,
            /* allow_audio           */ mctx ? mtmd_support_audio(mctx) : false,
            /* enable_thinking       */ enable_thinking,
            // /* media_path */ params_base.media_path,
        };
    }

}
|
|
|
|
|
|
// Serialize this slot's sequence state into the prompt cache.
// Best-effort: if the cache cannot allocate room for the state, nothing is saved.
// Precondition: the cached-prompt record holds no state data yet.
void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
    assert(server_cached_prompt.data.size() == 0);

    // size of the serialized llama state for this slot's sequence
    const size_t cur_size = llama_state_seq_get_size(ctx, id, 0);

    LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
        (int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));

    // alloc() may return nullptr when the entry does not fit the cache budget
    auto* cur = prompt_cache.alloc(server_cached_prompt, cur_size);
    if (cur == nullptr) {
        return;
    }

    llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
}
|
|
|
|
// Restore a cached prompt state matching `tokens` into this slot's sequence.
// A cache miss or a failed restore is logged and otherwise non-fatal.
void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
    const bool ok = prompt_cache.load(server_cached_prompt, tokens, ctx, id);
    if (!ok) {
        LLAMA_LOG_INFO("failed to load prompt from cache\n");
    }
}
|
|
|
|
// Return the slot to a pristine state between tasks: clears generation
// buffers, stop flags, speculative/sampler checkpoints, ban/allow rules,
// streaming block state and the attached task.
void server_slot::reset() {
    // prompt / generation progress
    n_prompt_tokens = 0;
    last_gentxt_size = 0;
    generated_text = "";
    truncated = false;
    stopped_eos = false;
    stopped_word = false;
    stopped_limit = false;
    stopping_word = "";
    n_past = 0;
    n_past_prompt = 0;
    n_sent_text = 0;
    // speculative decoding scratch state
    drafted.clear();
    i_batch_dft.clear();
    spec_ckpt_valid = false;
    // drop the sampler snapshot taken before a speculative batch, if any
    if (spec_ckpt_sampler) {
        common_sampler_free(spec_ckpt_sampler);
        spec_ckpt_sampler = nullptr;
    }
    n_sent_token_probs = 0;
    infill = false;
    // group-attention (self-extend) counters
    ga_i = 0;
    n_past_se = 0;
    chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

    logit_bias.clear();
    // token rewind/buffering state
    token_buffer.clear();
    rewind_count = 0;
    n_buffer = 0;
    rewind_status = false;

    generated_token_probs.clear();
    checkpoint_pos = 0;
    image_just_processed = false;
    do_checkpoint = false;

    // token/phrase banning state
    positional_bans.clear();
    ban_phrases.clear();
    ban_regex.clear();
    ban_regex_ci.clear();

    // allow-list constraint state
    allow_ruless.clear();
    allow_pieces.clear();
    allow_kws.clear();
    allow_kw_delay = 0;
    allow_idx = 0;

    // Reset speculative decoding stats
    n_draft_total = 0;
    n_draft_accepted = 0;
    chat_msg = {};
    json_schema = json();
    generated_tool_call_ids.clear();

    // streaming block state (Anthropic-style responses)
    anthropic_thinking_block_started = false;
    anthropic_text_block_started = false;

    // streaming block state (OpenAI "responses"-style)
    oai_resp_thinking_block_started = false;
    oai_resp_text_block_started = false;
    oai_resp_id.clear();
    oai_resp_reasoning_id.clear();
    oai_resp_message_id.clear();
    oai_resp_fc_id.clear();

    // detach the current task
    task.reset();
}
|
|
|
|
bool server_slot::need_embd() const {
|
|
return embedding || has_mtp;
|
|
}
|
|
|
|
// Whether the slot may still generate tokens. Updates n_remaining as a
// side effect. A per-request n_predict takes precedence over the global one;
// when both are -1 the budget is unlimited.
bool server_slot::has_budget(gpt_params& global_params) {
    // both limits disabled -> unlimited budget
    if (params.n_predict == -1 && global_params.n_predict == -1) {
        return true;
    }

    n_remaining = -1;

    if (params.n_predict != -1) {
        // request-level limit wins
        n_remaining = params.n_predict - n_decoded;
    } else if (global_params.n_predict != -1) {
        n_remaining = global_params.n_predict - n_decoded;
    }

    // budget remains only while the decoded count is below the limit
    return n_remaining > 0;
}
|
|
|
|
bool server_slot::available() const {
|
|
return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
|
|
}
|
|
|
|
bool server_slot::is_processing() const {
|
|
return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
|
|
}
|
|
|
|
void server_slot::add_token_string(const completion_token_output& token) {
|
|
if (command == SLOT_COMMAND_RELEASE) {
|
|
return;
|
|
}
|
|
generated_token_probs.push_back(token);
|
|
}
|
|
|
|
bool server_slot::can_speculate() const {
|
|
return (!!spec || has_mtp);
|
|
}
|
|
|
|
int server_slot::get_n_draft_max() const {
|
|
if (!can_speculate()) {
|
|
return 0;
|
|
}
|
|
|
|
// determine the max draft that fits the current slot state
|
|
int n_draft_max = params.speculative.n_max;
|
|
|
|
// note: slot.prompt is not yet expanded with the `id` token sampled above
|
|
// also, need to leave space for 1 extra token to allow context shifts
|
|
n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);
|
|
|
|
if (n_remaining > 0) {
|
|
n_draft_max = std::min(n_draft_max, n_remaining - 1);
|
|
}
|
|
|
|
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
|
|
|
|
if (n_draft_max < params.speculative.n_min) {
|
|
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, params.speculative.n_min);
|
|
n_draft_max = 0;
|
|
}
|
|
return n_draft_max;
|
|
}
|
|
|
|
// Release a processing slot: record the total generation time, flag the
// release command, return the slot to idle, detach the task and reset the
// decode state. A no-op for slots not currently processing.
void server_slot::release() {
    if (state == SLOT_STATE_PROCESSING) {
        // us -> ms
        t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
        command = SLOT_COMMAND_RELEASE;
        state = SLOT_STATE_IDLE;
        task.reset();
        llama_decode_reset();
    }

}
|
|
|
|
|
|
// Build a JSON object with prompt-processing and generation timing stats.
// NOTE(review): divides by n_prompt_tokens_processed and n_decoded — yields
// inf/nan when either counter is 0; presumably only called after at least one
// processed token — confirm at call sites.
json server_slot::get_formated_timings() const {
    return json{
        {"prompt_n", n_prompt_tokens_processed},
        {"prompt_ms", t_prompt_processing},
        {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed},
        {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed},

        {"predicted_n", n_decoded},
        {"predicted_ms", t_token_generation},
        {"predicted_per_token_ms", t_token_generation / n_decoded},
        {"predicted_per_second", 1e3 / t_token_generation * n_decoded},

        {"n_ctx", n_ctx},
        {"n_past", n_past},
    };
}
|
|
|
|
// Collect the slot's timing statistics into a result_timings struct.
// Mirrors get_formated_timings(), plus speculative draft counters.
// NOTE(review): same div-by-zero caveat as get_formated_timings() when
// n_prompt_tokens_processed or n_decoded is 0 — confirm at call sites.
result_timings server_slot::get_timings() const {
    result_timings timings;
    timings.prompt_n = n_prompt_tokens_processed;
    timings.prompt_ms = t_prompt_processing;
    timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
    timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

    timings.predicted_n = n_decoded;
    timings.predicted_ms = t_token_generation;
    timings.predicted_per_token_ms = t_token_generation / n_decoded;
    timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;

    timings.n_ctx = n_ctx;
    timings.n_past = n_past;

    // Add speculative metrics
    if (n_draft_total > 0) {
        timings.draft_n = n_draft_total;
        timings.draft_n_accepted = n_draft_accepted;
    }

    return timings;
}
|
|
|
|
// Re-parse the accumulated generated text into a structured chat message and
// compute the incremental diffs against the previously parsed message.
// `diffs` receives the changes; returns the (possibly updated) cached message.
const common_chat_msg& server_slot::update_chat_msg(std::vector<common_chat_msg_diff>& diffs) {
    auto previous_msg = chat_msg;
    // parsing is "partial" until generation has stopped on EOS
    auto new_msg = common_chat_parse(
        generated_text,
        /* is_partial= */ stop != STOP_TYPE_EOS,
        params.oaicompat_chat_syntax);
    if (!new_msg.empty()) {
        // keep tool-call ids stable across incremental re-parses
        new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
        chat_msg = new_msg;
        diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
    }
    //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", generated_text.c_str());
    //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.reasoning_content.c_str());
    //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.content.c_str());
    return chat_msg;
}
|
|
|
|
|
|
size_t server_slot::find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop) {
|
|
size_t stop_pos = std::string::npos;
|
|
|
|
for (const std::string& word : params.antiprompt) {
|
|
size_t pos;
|
|
|
|
if (is_full_stop) {
|
|
const size_t tmp = word.size() + last_token_size;
|
|
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
|
|
|
pos = text.find(word, from_pos);
|
|
}
|
|
else {
|
|
pos = string_find_partial_stop(text, word);
|
|
}
|
|
|
|
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
|
if (is_full_stop) {
|
|
stopped_word = true;
|
|
stopping_word = word;
|
|
has_next_token = false;
|
|
}
|
|
stop_pos = pos;
|
|
}
|
|
}
|
|
|
|
return stop_pos;
|
|
}
|
|
|
|
// Log per-request timing statistics (prompt eval + generation throughput),
// the draft acceptance rate and speculative subsystem stats.
// NOTE(review): divides by n_prompt_tokens_processed / n_decoded — produces
// inf/nan when either counter is 0; presumably only called after a completed
// request — confirm.
void server_slot::print_timings() const {
    // NOTE(review): `buffer` appears unused in this function
    char buffer[512];
    double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
    double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

    double t_gen = t_token_generation / n_decoded;
    double n_gen_second = 1e3 / t_token_generation * n_decoded;

    SLT_INF(*this,
        "\n"
        "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
        "       eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
        "      total time = %10.2f ms / %5d tokens\n",
        t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
        t_token_generation, n_decoded, t_gen, n_gen_second,
        t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);

    if (n_draft_total > 0) {
        const float draft_ratio = (float)n_draft_accepted / n_draft_total;
        SLT_CNT(*this,
            "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
            draft_ratio, n_draft_accepted, n_draft_total
        );
    }
    // const_cast: the stats printer takes a mutable speculative-params pointer
    common_speculative_print_stats(spec, n_gen_second, n_decoded, n_past,
        const_cast<common_params_speculative *>(&params.speculative));
}
|
|
|
|
// Record the server start timestamp (microseconds) as the metrics epoch.
void server_metrics::init() {
    t_start = ggml_time_us();
}
|
|
|
|
// Fold a slot's finished prompt evaluation into both the current metrics
// bucket and the lifetime totals.
void server_metrics::on_prompt_eval(const server_slot& slot) {
    const auto n_tokens = slot.n_prompt_tokens_processed;
    const auto t_ms     = slot.t_prompt_processing;

    n_prompt_tokens_processed       += n_tokens;
    n_prompt_tokens_processed_total += n_tokens;
    t_prompt_processing       += t_ms;
    t_prompt_processing_total += t_ms;
}
|
|
|
|
// Fold a slot's finished generation into both the current metrics bucket
// and the lifetime totals.
void server_metrics::on_prediction(const server_slot& slot) {
    const auto n_tokens = slot.n_decoded;
    const auto t_ms     = slot.t_token_generation;

    n_tokens_predicted       += n_tokens;
    n_tokens_predicted_total += n_tokens;
    t_tokens_generation       += t_ms;
    t_tokens_generation_total += t_ms;
}
|
|
|
|
// Zero the per-bucket counters; the lifetime *_total counters are untouched.
void server_metrics::reset_bucket() {
    n_prompt_tokens_processed = 0;
    n_tokens_predicted        = 0;
    t_prompt_processing       = 0;
    t_tokens_generation       = 0;
}
|
|
|
|
std::vector<llama_token> server_context::tokenize(const json& json_prompt, bool add_special) const {
|
|
// TODO: currently, we tokenize using special tokens by default
|
|
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
|
|
// but it's better compared to completely ignoring ChatML and other chat templates
|
|
const bool TMP_FORCE_SPECIAL = true;
|
|
|
|
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
|
// or the first element of the json_prompt array is a string.
|
|
std::vector<llama_token> prompt_tokens;
|
|
|
|
if (json_prompt.is_array()) {
|
|
bool first = true;
|
|
for (const auto& p : json_prompt) {
|
|
if (p.is_string()) {
|
|
auto s = p.template get<std::string>();
|
|
|
|
std::vector<llama_token> p;
|
|
if (first) {
|
|
p = ::common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
|
|
first = false;
|
|
}
|
|
else {
|
|
p = ::common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
|
|
}
|
|
|
|
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
|
}
|
|
else {
|
|
if (first) {
|
|
first = false;
|
|
}
|
|
|
|
prompt_tokens.push_back(p.template get<llama_token>());
|
|
}
|
|
}
|
|
}
|
|
else {
|
|
auto s = json_prompt.template get<std::string>();
|
|
prompt_tokens = ::common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
|
|
}
|
|
|
|
return prompt_tokens;
|
|
}
|
|
|
|
// Linear search of the slot list for a matching id; nullptr when absent.
server_slot* server_context::get_slot_by_id(int id) {
    for (size_t i = 0; i < slots.size(); ++i) {
        if (slots[i].id == id) {
            return &slots[i];
        }
    }
    return nullptr;
}
|
|
|
|
float server_context::calculate_slot_f_keep(const server_slot & slot, llama_context * ctx,const server_tokens & a, const server_tokens & b) {
|
|
float f_keep = 0.0f;
|
|
if (!a.empty()) {
|
|
if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
|
|
f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
|
|
}
|
|
else {
|
|
f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
|
|
}
|
|
}
|
|
return f_keep;
|
|
}
|
|
|
|
std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
|
|
std::pair<common_prefix, float> sim;
|
|
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
|
common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
|
|
// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
|
|
float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
|
|
// handle context shift
|
|
if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
|
|
float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
|
|
if (sim_cur_ctx_shift > sim_cur) {
|
|
sim_cur = sim_cur_ctx_shift;
|
|
}
|
|
}
|
|
sim.first = lcp_len;
|
|
sim.second = sim_cur;
|
|
return sim;
|
|
}
|
|
|
|
// Snapshot the given token cache plus the slot's context-shift bookkeeping
// into the slot's cached-prompt record.
void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
    auto & dst = slot.server_cached_prompt;

    dst.tokens             = tokens.clone(); // deep copy of the cache tokens
    dst.n_discarded_prompt = slot.n_discarded_prompt;
    dst.n_kept_prompt      = slot.n_kept_prompt;
    dst.think_tokens       = slot.params.think_tokens;
}
|
|
|
|
// Pick a slot for `task`: first by prompt similarity (when enabled), then by
// least-recently-used. Once selected, the slot's soon-to-be-lost context may
// be saved to the prompt cache, and a better-matching cached prompt may be
// restored into it. Returns nullptr when no slot is available.
server_slot* server_context::get_available_slot(const server_task& task) {
    server_slot* ret = nullptr;
    bool update_cache = false;

    // find the slot that has at least n% prompt similarity
    if (ret == nullptr && slot_prompt_similarity != 0.0f) {
        int max_lcp_len = 0;
        float sim_best = 0;

        for (server_slot& slot : slots) {
            // skip the slot if it is not available
            if (!slot.available()) {
                continue;
            }
            auto& cache_tokens = slot.cache_tokens;
            // skip the slot if it does not contains prompt
            if (cache_tokens.empty()) {
                continue;
            }
            std::pair<common_prefix, float> sim;
            // optionally compare with reasoning ("think") tokens stripped out
            if (slot.params.think_tokens.exclude) {
                server_tokens cache_tokens_exclude_think = slot.cache_tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
                server_tokens prompt_tokens_exclude_think = task.tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
                sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
            }
            else {
                sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
            }
            common_prefix lcp_len = sim.first;
            float sim_cur = sim.second;

            // select the current slot if the criteria match
            if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
                sim_best = sim_cur;
                max_lcp_len = lcp_len.first;
                ret = &slot;
            }
        }
        if (ret != nullptr) {
            LOG_VERBOSE("selected slot by lcp similarity", {
                {"id_slot", ret->id},
                {"max_lcp_len", max_lcp_len},
                {"similarity", sim_best},
            });
        }
    }

    // find the slot that has been least recently used
    if (ret == nullptr) {
        int64_t t_last = ggml_time_us();
        for (server_slot& slot : slots) {
            // skip the slot if it is not available
            if (!slot.available()) {
                continue;
            }
            // select the current slot if the criteria match
            if (slot.t_last_used < t_last) {
                t_last = slot.t_last_used;
                ret = &slot;
            }
        }

        if (ret != nullptr) {
            LOG_VERBOSE("selected slot by lru", {
                {"id_slot", ret->id},
                {"t_last", t_last},
            });
        }
    }
    // prompt-cache bookkeeping for the chosen slot
    if (ret) {
        auto& tokens = ret->cache_tokens;
        float f_keep = 0;
        size_t cache_token_size = tokens.size();
        if (!tokens.empty()) {
            // estimate how much of the slot's existing cache survives this task
            if (ret->params.think_tokens.exclude) {
                server_tokens cache_exclude_think = tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
                server_tokens prompt_exclude_think = task.tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);

                cache_token_size = cache_exclude_think.size();
                f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
            }
            else {
                f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
            }
            // if we are about to lose a large portion of the existing context - save it in the prompt cache
            if (f_keep < cache_ram_similarity) {
                update_cache = true;
            }
        }

        update_cache = update_cache && prompt_cache;
        // cache prompts only for completion tasks
        update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;

        // don't update the cache if the slot's context is above cache_ram_n_min
        update_cache = update_cache && cache_token_size >= cache_ram_n_min;

        LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
            (int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity);
        if (update_cache) {
            const int64_t t_start = ggml_time_us();
            LLAMA_LOG_INFO("updating prompt cache\n");
            // copy cache tokens
            copy_data_to_cached_prompt(tokens, *ret);

            ret->prompt_save(*prompt_cache);
            LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
        }
        // has prompts saved earlier to load
        if (prompt_cache && !prompt_cache->states.empty()) {
            const int64_t t_start = ggml_time_us();
            copy_data_to_cached_prompt(tokens, *ret);

            ret->prompt_load(*prompt_cache, task.tokens);
            prompt_cache->update();

            // the loaded state replaces the slot's cache bookkeeping
            ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens
            ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt;
            ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt;

            LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
        }
    }
    return ret;
}
|
|
|
|
int32_t server_context::populate_vocab_pieces() {
|
|
const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
|
if (vocab_pieces.size() == n_vocab) {
|
|
return n_vocab;
|
|
}
|
|
vocab_pieces.clear();
|
|
vocab_pieces.reserve(n_vocab);
|
|
for (int32_t id = 0; id < n_vocab; ++id) {
|
|
vocab_pieces.push_back(common_token_to_piece(ctx, id, true));
|
|
}
|
|
return n_vocab;
|
|
}
|
|
|
|
// Configure `slot` from an incoming request task: copy request/sampling
// parameters (falling back to server-wide defaults), validate them, compile
// grammar / logit-bias / banned-phrase / allowlist tables, and (re)create the
// per-slot sampler. Returns false (after send_error) on invalid requests and
// may throw std::runtime_error on malformed parameter values (the caller
// converts that into an error response). On success the slot is left in
// SLOT_COMMAND_LOAD_PROMPT state and takes ownership of `task`.
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
    slot_params defaults;
    defaults.speculative = params_base.speculative;

    // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
    common_params_sampling default_sparams = params_base.sparams;
    auto& data = task.data;
    const llama_vocab* vocab = llama_model_get_vocab(model);

    // OpenAI-compatible mode is flagged by the request pre-processor
    if (data.count("__oaicompat") != 0) {
        slot.oaicompat = true;
        slot.oaicompat_model = task.params.oaicompat_model;
    }
    else {
        slot.oaicompat = false;
        slot.oaicompat_model = "";
    }
    slot.params.oaicompat = task.params.oaicompat;
    slot.params.oaicompat_cmpl_id = task.params.oaicompat_cmpl_id;

    // fresh per-response identifiers for the OpenAI "responses" API
    slot.oai_resp_thinking_block_started = false;
    slot.oai_resp_text_block_started = false;
    slot.oai_resp_id = "resp_" + random_string();
    slot.oai_resp_reasoning_id = "rs_" + random_string();
    slot.oai_resp_message_id = "msg_" + random_string();
    slot.oai_resp_fc_id.clear();

    // generic request parameters
    slot.params.timings_per_token = json_value(data, "timings_per_token", false);
    slot.params.stream = json_value(data, "stream", false);
    auto stream_opt = json_value(data, "stream_options", json::object());
    slot.params.include_usage = json_value(stream_opt, "include_usage", false);
    slot.params.cache_prompt = json_value(data, "cache_prompt", true);
    slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
    slot.saturate_predict = json_value(data, "saturate_predict", false);

    // sampling parameters (request value wins, otherwise server default)
    slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
    slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
    slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
    slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
    slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
    slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
    slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
    slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
    slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
    slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
    slot.sparams.top_n_sigma = json_value(data, "top_n_sigma", default_sparams.top_n_sigma);
    slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
    slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
    slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
    slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
    slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
    slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
    slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
    slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
    slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
    slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
    slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
    slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
    slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
    slot.sparams.adaptive_updt_w_cur = json_value(data, "adaptive_updt_w_cur", default_sparams.adaptive_updt_w_cur);
    slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
    slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
    slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
    slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
    slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
    slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

    slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);

    // speculative decoding parameters
    slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.speculative.n_max);
    slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.speculative.n_min);
    slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.speculative.p_min);

    // clamp: 0 <= n_min <= n_max
    slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
    slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
    slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);

    slot.params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));

    // NOTE(review): the request key is "speculative.ngram_m_hits" (not
    // "ngram_min_hits") — kept as-is for client compatibility; confirm intended.
    slot.params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
    slot.params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
    slot.params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);

    // clamp the ngram parameters into [1, 1024]
    // (fixed: the previous std::max(std::min(1, x), 1024) always produced 1024)
    slot.params.speculative.ngram_size_n = std::min(std::max(1, (int)slot.params.speculative.ngram_size_n), 1024);
    slot.params.speculative.ngram_size_m = std::min(std::max(1, (int)slot.params.speculative.ngram_size_m), 1024);
    slot.params.speculative.ngram_min_hits = std::min(std::max(1, (int)slot.params.speculative.ngram_min_hits), 1024);


    if (slot.sparams.penalty_last_n < -1) {
        throw std::runtime_error("Error: repeat_last_n must be >= -1");
    }

    if (slot.sparams.dry_penalty_last_n < -1) {
        throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
    }

    if (slot.sparams.penalty_last_n == -1) {
        // note: should be the slot's context and not the full context, but it's ok
        slot.sparams.penalty_last_n = llama_n_ctx(ctx);
    }

    if (slot.sparams.dry_penalty_last_n == -1) {
        slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);

    }
    if (slot.sparams.dry_base < 1.0f)
    {
        slot.sparams.dry_base = default_sparams.dry_base;
    }

    // sequence breakers for DRY
    {
        // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
        // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39

        if (data.contains("dry_sequence_breakers")) {
            slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
            if (slot.sparams.dry_sequence_breakers.empty()) {
                send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
                return false;
            }
        }
    }

    // process "json_schema" and "grammar"
    if (data.contains("json_schema") && !data.contains("grammar")) {
        try {
            auto schema = json_value(data, "json_schema", json::object());
            LLAMA_LOG_DEBUG("JSON schema: %s\n", schema.dump(2).c_str());
            slot.sparams.grammar = json_schema_to_grammar(schema);
            LLAMA_LOG_DEBUG("Converted grammar: %s\n", slot.sparams.grammar.c_str());
        }
        catch (const std::exception& e) {
            throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
        }
    }
    else {
        slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
        LLAMA_LOG_DEBUG("Grammar: %s\n", slot.sparams.grammar.c_str());
        slot.sparams.grammar_lazy = json_value(data, "grammar_lazy", default_sparams.grammar_lazy);
        LLAMA_LOG_DEBUG("Grammar lazy: %s\n", slot.sparams.grammar_lazy ? "true" : "false");
    }

    if (slot.params.cache_prompt && slot.ga_n != 1) {
        LOG_WARNING("cache_prompt is not supported with group-attention", {});
        slot.params.cache_prompt = false;
    }

    if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
        // Might be better to reject the request with a 400 ?
        LOG_WARNING("Max tokens to predict exceeds server configuration", {
            {"params.n_predict", slot.params.n_predict},
            {"slot.n_predict", slot.n_predict},
        });
        slot.params.n_predict = slot.n_predict;
    }

    // infill
    slot.params.input_prefix = json_value(data, "input_prefix", defaults.input_prefix);
    slot.params.input_suffix = json_value(data, "input_suffix", defaults.input_suffix);

    // get prompt
    if (!task.infill) {
        slot.prompt_tokens = std::move(task.tokens);

        const auto & prompt = data.find("prompt");
        if (prompt != data.end()) {
            // accept: a string, an array of token ids, a one-element array of
            // either form (the single element is unwrapped for nested arrays)
            if (prompt->is_string() ||
                (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
                slot.prompt = *prompt;
            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
                slot.prompt = *prompt;
            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                slot.prompt = prompt->at(0);
            }
        }
    }

    // penalize user-provided tokens
    {
        slot.sparams.penalty_prompt_tokens.clear();
        slot.sparams.use_penalty_prompt_tokens = false;

        const auto& penalty_prompt = data.find("penalty_prompt");

        if (penalty_prompt != data.end()) {
            if (penalty_prompt->is_string()) {
                const auto penalty_prompt_string = penalty_prompt->get<std::string>();
                slot.sparams.penalty_prompt_tokens = common_tokenize(model, penalty_prompt_string, false);

                // generated tokens are appended to this list during decoding
                if (slot.params.n_predict > 0) {
                    slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
                }
                slot.sparams.use_penalty_prompt_tokens = true;

                LOG_VERBOSE("penalty_prompt_tokens", {
                    {"id_slot", slot.id},
                    {"tokens", slot.sparams.penalty_prompt_tokens},
                });
            }
            else if (penalty_prompt->is_array()) {
                const auto n_tokens = penalty_prompt->size();
                slot.sparams.penalty_prompt_tokens.clear();
                slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));

                const int n_vocab = llama_n_vocab(model);
                for (const auto& penalty_token : *penalty_prompt) {
                    if (penalty_token.is_number_integer()) {
                        const auto tok = penalty_token.get<llama_token>();
                        // silently drop out-of-vocabulary ids
                        if (tok >= 0 && tok < n_vocab) {
                            slot.sparams.penalty_prompt_tokens.push_back(tok);
                        }
                    }
                }
                slot.sparams.use_penalty_prompt_tokens = true;

                LOG_VERBOSE("penalty_prompt_tokens", {
                    {"id_slot", slot.id},
                    {"tokens", slot.sparams.penalty_prompt_tokens},
                });
            }
        }
    }

    // chat format / reasoning syntax
    {
        auto it = data.find("chat_format");
        if (it != data.end()) {
            slot.params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
            LLAMA_LOG_DEBUG("Chat format: %s\n", common_chat_format_name(slot.params.oaicompat_chat_syntax.format));
        }
        else {
            slot.params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
        }
        common_reasoning_format reasoning_format = params_base.reasoning_format;
        if (data.contains("reasoning_format")) {
            reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
        }
        slot.params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
        slot.params.oaicompat_chat_syntax.reasoning_in_content = slot.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
        slot.params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
        if (data.contains("chat_parser")) {
            slot.params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
        }
        slot.params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
    }

    // preserved tokens and grammar triggers
    {
        const auto preserved_tokens = data.find("preserved_tokens");
        if (preserved_tokens != data.end()) {
            slot.sparams.preserved_tokens.clear();
            for (const auto& t : *preserved_tokens) {
                auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
                if (ids.size() == 1) {
                    LOG("Preserved token: %d\n", ids[0]);
                    slot.sparams.preserved_tokens.insert(ids[0]);
                }
                else {
                    // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
                    LOG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
                }
            }
        }
        const auto grammar_triggers = data.find("grammar_triggers");
        if (grammar_triggers != data.end()) {
            slot.sparams.grammar_triggers.clear();
            for (const auto& t : *grammar_triggers) {
                server_grammar_trigger ct(t);
                if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                    const auto& word = ct.value.value;
                    auto ids = common_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
                    if (ids.size() == 1) {
                        auto token = ids[0];
                        // single-token triggers must also be preserved tokens, otherwise
                        // the detokenizer could merge them away before the trigger fires
                        if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) {
                            throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
                        }
                        LOG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
                        common_grammar_trigger trigger;
                        trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
                        trigger.value = word;
                        trigger.token = token;
                        slot.sparams.grammar_triggers.push_back(std::move(trigger));
                    }
                    else {
                        LOG("Grammar trigger word: `%s`\n", word.c_str());
                        slot.sparams.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word });
                    }
                }
                else {
                    //slot.sparams.grammar_triggers.push_back(ct);
                    if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
                        LLAMA_LOG_DEBUG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
                    }
                    else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
                        LLAMA_LOG_DEBUG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
                    }
                    else {
                        throw std::runtime_error("Unknown grammar trigger type");
                    }
                    slot.sparams.grammar_triggers.emplace_back(std::move(ct.value));
                }
            }
        }

        if (slot.sparams.grammar_lazy && slot.sparams.grammar_triggers.empty()) {
            throw std::runtime_error("Error: no triggers set for lazy grammar!");
        }
    }

    { // apply logit bias
        const auto& logit_bias = data.find("logit_bias");
        if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
            slot.sparams.logit_bias.clear(); // only clear if user sets it
        }
        if (logit_bias != data.end() && logit_bias->is_array()) {
            // array form: [[token_or_string, bias_or_false], ...]
            const int n_vocab = llama_n_vocab(model);
            for (const auto& el : *logit_bias) {
                // TODO: we may want to throw errors here, in case "el" is incorrect
                if (el.is_array() && el.size() == 2) {
                    float bias;
                    if (el[1].is_number()) {
                        bias = el[1].get<float>();
                    }
                    else if (el[1].is_boolean() && !el[1].get<bool>()) {
                        // `false` means: ban the token outright
                        bias = -INFINITY;
                    }
                    else {
                        continue;
                    }

                    if (el[0].is_number_integer()) {
                        llama_token tok = el[0].get<llama_token>();
                        if (tok >= 0 && tok < n_vocab) {
                            slot.sparams.logit_bias[tok] = bias;
                        }
                    }
                    else if (el[0].is_string()) {
                        auto toks = common_tokenize(model, el[0].get<std::string>(), false);
                        for (auto tok : toks) {
                            slot.sparams.logit_bias[tok] = bias;
                        }
                    }
                }
            }
        }
        else if (logit_bias != data.end() && logit_bias->is_object()) {
            // object form: {"token_id_or_string": bias_or_false, ...}
            const int n_vocab = llama_vocab_n_tokens(vocab);
            for (const auto& el : logit_bias->items()) {
                float bias;
                const auto& key = el.key();
                const auto& value = el.value();
                if (value.is_number()) {
                    bias = value.get<float>();
                }
                else if (value.is_boolean() && !value.get<bool>()) {
                    bias = -INFINITY;
                }
                else {
                    continue;
                }

                // keys that parse fully as integers are token ids; anything else is tokenized
                char* end;
                llama_token tok = strtol(key.c_str(), &end, 10);
                if (*end == 0) {
                    if (tok >= 0 && tok < n_vocab) {
                        slot.sparams.logit_bias[tok] = bias;
                    }
                }
                else {
                    auto toks = common_tokenize(model, key, false);
                    for (auto tok : toks) {
                        slot.sparams.logit_bias[tok] = bias;
                    }
                }
            }
        }
        if (json_value(data, "ignore_eos", false) && has_eos_token) {
            slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
        }

    }

    {
        // ban string
        int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
        slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
        slot.rewind_count_max = json_value(data, "rewind_count_max", -1);

        const auto& banned_strings = data.find("banned_strings");
        if (banned_strings != data.end() && banned_strings->is_array()) {
            // per-request banned strings override the server-wide list
            slot.ban_phrases.clear();
            for (const auto& val : data["banned_strings"]) {
                if (val.is_string()) {
                    std::string s = val.get<std::string>();
                    if (!s.empty()) {
                        s = string_lower(s);
                        // Use string length instead of token count
                        if (s.length() > slot.n_buffer) {
                            slot.n_buffer = s.length();
                        }
                        slot.ban_phrases.push_back(s);
                    }
                }
            }
            // longest phrase first, so matching prefers the longest ban
            std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
                return a.length() > b.length();
            });
        } else if (params_base.ban_phrases.size() > 0) {
            if (params_base.n_buffer == 0) {
                // first request: normalize the server-wide list in place once
                // (params_base.n_buffer != 0 afterwards marks it as done)
                slot.ban_phrases.clear();
                std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) {
                    return a.length() > b.length();
                });
                for (auto & val : params_base.ban_phrases) {
                    if (!val.empty()) {
                        val = string_lower(val);
                        // Use string length instead of token count
                        if (val.length() > slot.n_buffer) {
                            slot.n_buffer = val.length();
                        }
                        slot.ban_phrases.push_back(val);
                    }
                }
                params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
            } else {
                slot.ban_phrases = params_base.ban_phrases;
                slot.n_buffer = params_base.n_buffer;
            }
        }

        // ban regex
        slot.ban_regex.clear();
        const auto& banned_regex = data.find("banned_regex");
        if (banned_regex != data.end() && banned_regex->is_array()) {
            for (const auto& val : data["banned_regex"]) {
                if (val.is_string()) {
                    std::string s = val.get<std::string>();
                    if (!s.empty()) {
                        // compile once here to validate; the pattern string is stored
                        try {
                            std::regex re(s);
                            slot.ban_regex.push_back(s);
                            if (s.length() > slot.n_buffer) {
                                slot.n_buffer = s.length();
                            }
                        } catch (const std::regex_error& e) {
                            send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
                            return false;
                        }
                    }
                }
            }
        }

        // ban regex case insensitive
        slot.ban_regex_ci.clear();
        const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
        if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
            for (const auto& val : data["banned_regex_case_insensitive"]) {
                if (val.is_string()) {
                    std::string s = val.get<std::string>();
                    if (!s.empty()) {
                        try {
                            std::regex re(s, std::regex_constants::icase);
                            slot.ban_regex_ci.push_back(s);
                            if (s.length() > slot.n_buffer) {
                                slot.n_buffer = s.length();
                            }
                        } catch (const std::regex_error& e) {
                            send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
                            return false;
                        }
                    }
                }
            }
        }

        // NOTE(review): this leaves n_buffer >= 1 even when nothing is banned,
        // which makes the "banned strings is not supported by recurrent model"
        // warning below fire unconditionally for recurrent models — confirm.
        if (banbuffer_size > 0) {
            slot.n_buffer = banbuffer_size;
        } else {
            slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
        }

        slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
        slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
        slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
    }

    do // populate allowlist biases
    {
        // TODO: JSON parsing for rules and keywords
        slot.allow_ruless = params_base.allow_ruless;
        if (slot.allow_ruless.size() == 0) {
            slot.allow_biasess.clear();
            break;
        }
        slot.allow_kws = params_base.allow_kws;

        slot.allow_pieces = params_base.allow_pieces;
        const auto& allowlist_piece_array = data.find("allowlist_piece_array");
        if (allowlist_piece_array != data.end() && allowlist_piece_array->is_array()) {
            slot.allow_pieces.clear();
            for (const auto& piece: *allowlist_piece_array) {
                if (piece.is_string()) {
                    slot.allow_pieces.push_back(piece.get<std::string>());
                }
            }
        }

        slot.allow_kw_delay = json_value(data, "allowlist_keyword_delay", params_base.allow_kw_delay);
        // end of allowlist criteria update

        const int32_t n_vocab = populate_vocab_pieces();

        // tokens of explicitly-allowed pieces always get the maximum bias
        std::unordered_set<llama_token> allow_settoken;
        for (const auto& piece: slot.allow_pieces) {
            for (const auto token: common_tokenize(model, piece, false, true)) {
                allow_settoken.insert(token);
            }
        }

        auto n_rules = slot.allow_ruless.size();
        if (n_rules > slot.allow_kws.size() + 1) {
            // one more rules than keyword, last rules do not expire
            n_rules = slot.allow_kws.size() + 1;
            slot.allow_ruless.resize(n_rules);
        } else if (n_rules < slot.allow_kws.size()) {
            // every rules expire
            slot.allow_kws.resize(n_rules);
        }
        slot.allow_biasess.resize(n_rules);

        for (size_t i = 0; i < n_rules; ++i) {
            const auto& rules = slot.allow_ruless[i];
            // bias tables are expensive — reuse the one from the previous
            // request when the rule set did not change
            if ((i < slot.allow_ruless_prev.size()) && (rules == slot.allow_ruless_prev[i])) {
                continue;
            }
            LLAMA_LOG_DEBUG("%s: allowlist %zu is new\n", __func__, i);

            auto& biases = slot.allow_biasess[i];
            biases.resize(n_vocab);

            std::vector<uint32_t> cpts;
            std::vector<std::string> scripts;
            for (size_t id = 0; id < n_vocab; ++id) {
                const size_t n_cpt = llama_fill_from_utf8(&vocab_pieces[id], &cpts, &scripts);
                float bias = -INFINITY;

                // each codepoint must be found in
                for (size_t j = 0; j < n_cpt; ++j) {
                    bool in_rule = false;

                    // at least one rule
                    for (const auto& rule: rules) {
                        const bool in_range = (std::get<0>(rule) <= cpts[j]) && (cpts[j] <= std::get<1>(rule));
                        in_rule = in_range && ((std::get<2>(rule) == "*") || std::get<2>(rule) == scripts[j]);
                        if (in_rule) {
                            // earlier rule has higher priority
                            bias = std::max(bias, std::get<3>(rule));
                            break;
                        }
                    }
                    if (!in_rule) {
                        if ((scripts[j] == "common") || (scripts[j] == "inherited")) {
                            // for common or inherited codepoints (e.g. whitespace), defer to other codepoints in the token
                            continue;
                        }

                        // to shadow realm
                        bias = -INFINITY;
                        break;
                    }
                }
                biases[id] = bias;
            }

            float max_bias = -INFINITY;
            for (const auto& rule: rules) {
                max_bias = std::max(max_bias, std::get<3>(rule));
            }
            for (const auto token: allow_settoken) {
                biases[token] = max_bias;
            }
        }
    } while (false);
    slot.allow_ruless_prev = slot.allow_ruless;

    if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
        params_base.can_ban_phrases = false;
        bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
        // make checkpoints only for completion tasks
        do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
        // make a checkpoint of the parts of the memory that cannot be rolled back.
        // checkpoints are created only if:
        // - the model architecture is marked as recurrent or hybrid
        //
        // TODO: try to make this conditional on the context or the memory module, instead of the model type
        params_base.do_checkpoint = do_checkpoint;
        if (slot.n_buffer != 0) {
            LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n");
        }
        if (params_base.ctx_shift) {
            params_base.ctx_shift = false;
            LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled");
        }
    }

    // stop words
    {
        const auto& stop = data.find("stop");
        if (stop != data.end() && stop->is_array()) {
            slot.params.antiprompt.clear();
            for (const auto& word : *stop) {
                if (!word.empty()) {
                    slot.params.antiprompt.push_back(word);
                }
            }
        }
    }

    // sampler chain order
    {
        const auto samplers = data.find("samplers");
        if (samplers != data.end()) {
            if (samplers->is_array()) {
                slot.sparams.samplers_sequence = llama_sampling_types_from_names(*samplers, false);
            }
            else if (samplers->is_string()) {
                slot.sparams.samplers_sequence = llama_sampling_types_from_chars(samplers->get<std::string>());
            }
            else {
                slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
            }
        }
    }

    // (re)create the per-slot sampler with the final parameters
    {
        if (slot.ctx_sampling != nullptr) {
            common_sampler_free(slot.ctx_sampling);
        }
        slot.ctx_sampling = common_sampler_init(model, slot.sparams);
        if (slot.ctx_sampling == nullptr) {
            // for now, the only error that may happen here is invalid grammar
            send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
            return false;
        }
    }

    slot.command = SLOT_COMMAND_LOAD_PROMPT;
    // slot.prompt_tokens.clear();

    LOG_INFO("slot is processing task", {
        {"id_slot", slot.id},
        {"id_task", slot.id_task},
    });
    slot.task = std::make_unique<const server_task>(std::move(task));
    return true;
}
|
|
|
|
// Drop every sequence from the KV cache and reset the pending-clear flag.
void server_context::kv_cache_clear() {
    LOG_VERBOSE("clearing KV cache", {});

    // wipe all sequences at once; the flag records that the wipe happened
    llama_kv_cache_clear(ctx);
    clean_kv_cache = false;
}
|
|
|
|
void server_context::system_prompt_update() {
|
|
LOG_VERBOSE("system prompt update", {
|
|
{"system_prompt", system_prompt},
|
|
});
|
|
|
|
kv_cache_clear();
|
|
system_tokens.clear();
|
|
|
|
if (!system_prompt.empty()) {
|
|
system_tokens = ::common_tokenize(ctx, system_prompt, true);
|
|
|
|
const int32_t n_batch = llama_n_batch(ctx);
|
|
const int32_t n_tokens_prompt = system_tokens.size();
|
|
|
|
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
|
|
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
|
|
|
|
common_batch_clear(batch);
|
|
|
|
for (int32_t j = 0; j < n_tokens; ++j) {
|
|
common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
|
|
}
|
|
|
|
if (llama_decode(ctx, batch) != 0) {
|
|
LOG_ERROR("llama_decode() failed", {});
|
|
return;
|
|
}
|
|
}
|
|
|
|
// assign the system KV cache to all parallel sequences
|
|
for (int32_t i = 1; i <= params_base.n_parallel; ++i) {
|
|
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
|
}
|
|
}
|
|
|
|
system_need_update = false;
|
|
}
|
|
|
|
bool server_context::system_prompt_set(const std::string& sys_prompt) {
|
|
system_prompt = sys_prompt;
|
|
|
|
LOG_VERBOSE("system prompt process", {
|
|
{"system_prompt", system_prompt},
|
|
});
|
|
|
|
// release all slots
|
|
for (server_slot& slot : slots) {
|
|
slot.release();
|
|
}
|
|
|
|
system_need_update = true;
|
|
return true;
|
|
}
|
|
|
|
// keep in sync with process_token(completion_token_output& result, server_slot& slot)
|
|
bool server_context::has_next_token(const completion_token_output& result, server_slot& slot) {
|
|
bool next = true;
|
|
//std::string generate_text = slot.generated_text + result.text_to_send;
|
|
//bool incomplete = validate_utf8(generate_text) < generate_text.size();
|
|
//if (incomplete) {
|
|
// next = true;
|
|
//}
|
|
if (slot.n_decoded > 0 && !slot.has_budget(params_base)) {
|
|
next = false;
|
|
}
|
|
if (llama_token_is_eog(model, result.tok)) {
|
|
next = false;
|
|
}
|
|
auto n_ctx_train = llama_n_ctx_train(model);
|
|
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
|
|
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
|
|
next = false;
|
|
}
|
|
return next;
|
|
}
|
|
|
|
|
|
// Incorporate one sampled token into the slot's generated text, handle stop
// strings / UTF-8 boundaries / budget limits, optionally stream a partial
// response, and decide whether generation continues.
// Returns slot.has_next_token (true = keep generating).
// Keep the stopping conditions in sync with has_next_token().
bool server_context::process_token(completion_token_output& result, server_slot& slot) {
    // remember which tokens were sampled - used for repetition penalties during sampling
    const std::string token_str = result.text_to_send;
    slot.sampled = result.tok;

    // search stop word and delete it
    // last_gentxt_size lets later code know where this token's text started
    slot.last_gentxt_size = slot.generated_text.size();
    slot.generated_text += token_str;
    slot.has_next_token = true;

    if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
        // we can change penalty_prompt_tokens because it is always created from scratch each request
        slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
    }

    // check if there is incomplete UTF-8 character at the end
    // (if so, buffer the text and wait for the next token to complete it)
    bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

    if (!incomplete) {
        // `pos` = start of the not-yet-sent portion of generated_text
        size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

        const std::string str_test = slot.generated_text.substr(pos);
        bool send_text = true;

        // full match: a stop string completed — truncate it from the output
        size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
        if (stop_pos != std::string::npos) {
            slot.generated_text.erase(
                slot.generated_text.begin() + pos + stop_pos,
                slot.generated_text.end());
            pos = std::min(slot.n_sent_text, slot.generated_text.size());
        }
        else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) {
            // partial match: a stop string may still be forming — hold the text back
            stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
            send_text = stop_pos == std::string::npos;
        }

        // check if there is any token to predict
        if (send_text) {
            // no send the stop word in the response
            result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
            slot.n_sent_text += result.text_to_send.size();
            // add the token to slot queue and cache
        }
        else {
            result.text_to_send = "";
        }

        slot.add_token_string(result);
        if (slot.params.stream) {
            send_partial_response(slot, result);
        }
    }

    if (incomplete) {
        // force another iteration so the UTF-8 sequence can be completed
        slot.has_next_token = true;
    }

    // check the limits
    if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
        slot.stopped_limit = true;
        slot.has_next_token = false;

        LOG_VERBOSE("stopped by limit", {
            {"id_slot", slot.id},
            {"id_task", slot.id_task},
            {"n_decoded", slot.n_decoded},
            {"n_predict", slot.params.n_predict},
        });
    }

    if (llama_token_is_eog(model, result.tok)) {
        slot.stopped_eos = true;
        slot.has_next_token = false;

        LOG_VERBOSE("eos token found", {});
    }

    // without an n_predict limit and without self-context extend, cap generation
    // at the training context to avoid an EOS-less infinite loop
    auto n_ctx_train = llama_n_ctx_train(model);
    if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
        && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
        LOG_WARNING("n_predict is not set and self-context extend is disabled."
            " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
            { "id_slot", slot.id },
            { "params.n_predict", slot.params.n_predict },
            { "slot.n_prompt_tokens", slot.n_prompt_tokens },
            { "slot.n_decoded", slot.n_decoded },
            { "slot.n_predict", slot.n_predict },
            { "n_slots", params_base.n_parallel },
            { "slot.n_ctx", slot.n_ctx },
            { "n_ctx", n_ctx },
            { "n_ctx_train", n_ctx_train },
            { "ga_n", slot.ga_n },
            });
        slot.truncated = true;
        slot.stopped_limit = true;
        slot.has_next_token = false; // stop prediction
    }
    log_text(params_base, "token:"+result.text_to_send);
    LOG_VERBOSE("next token", {
        {"id_slot", slot.id},
        {"id_task", slot.id_task},
        {"token", result.tok},
        {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
        {"has_next_token", slot.has_next_token},
        {"n_remain", slot.n_remaining},
        {"n_decoded", slot.n_decoded},
        {"stopped_eos", slot.stopped_eos},
        {"stopped_word", slot.stopped_word},
        {"stopped_limit", slot.stopped_limit},
        {"stopping_word", slot.stopping_word},
        });

    return slot.has_next_token; // continue
}
|
|
|
|
void server_context::populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx) {
|
|
size_t n_probs = slot.sparams.n_probs;
|
|
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
|
|
if (post_sampling) {
|
|
const auto* cur_p = common_sampler_get_candidates(slot.ctx_sampling);
|
|
const size_t max_probs = cur_p->size;
|
|
|
|
// set probability for sampled token
|
|
for (size_t i = 0; i < max_probs; i++) {
|
|
if (cur_p->data[i].id == result.tok) {
|
|
result.prob = cur_p->data[i].p;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// set probability for top n_probs tokens
|
|
result.probs.reserve(max_probs);
|
|
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
|
result.probs.push_back({
|
|
cur_p->data[i].id,
|
|
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
|
cur_p->data[i].p
|
|
});
|
|
}
|
|
}
|
|
else {
|
|
auto&& [sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
|
|
|
|
// set probability for sampled token
|
|
result.prob = sampled_token_p;
|
|
|
|
// set probability for top n_probs tokens
|
|
result.probs.reserve(n_probs);
|
|
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
|
result.probs.push_back({
|
|
cur[i].id,
|
|
common_token_to_piece(ctx, cur[i].id, special),
|
|
cur[i].p
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build a JSON snapshot of the generation settings currently active on `slot`.
// Used both in completion responses ("generation_settings") and in the
// /slots metrics endpoint.
json server_context::get_formated_generation(const server_slot& slot) const {
    // "ignore_eos" is reported as true when the EOS token carries a -inf logit bias.
    const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
    const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);

    // render the sampler chain as a list of human-readable names
    std::vector<std::string> samplers_sequence;
    samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
    for (const auto& sampler_type : slot.sparams.samplers_sequence) {
        samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
    }

    // grammar lazy-trigger definitions, serialized per trigger
    auto grammar_triggers = json::array();
    for (const auto& trigger : slot.sparams.grammar_triggers) {
        grammar_triggers.push_back(trigger.to_json<json>());
    }

    return json{
        {"n_ctx", slot.n_ctx},
        {"n_predict", slot.n_predict}, // Server configured n_predict
        {"model", params_base.model_alias},
        {"seed", slot.sparams.seed},
        {"temperature", slot.sparams.temp},
        {"dynatemp_range", slot.sparams.dynatemp_range},
        {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
        {"top_k", slot.sparams.top_k},
        {"top_p", slot.sparams.top_p},
        {"min_p", slot.sparams.min_p},
        {"tfs_z", slot.sparams.tfs_z},
        {"typical_p", slot.sparams.typical_p},
        {"repeat_last_n", slot.sparams.penalty_last_n},
        {"repeat_penalty", slot.sparams.penalty_repeat},
        {"presence_penalty", slot.sparams.penalty_present},
        {"frequency_penalty", slot.sparams.penalty_freq},
        {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
        {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
        {"dry_multiplier", slot.sparams.dry_multiplier},
        {"dry_base", slot.sparams.dry_base},
        {"dry_allowed_length", slot.sparams.dry_allowed_length},
        {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
        {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
        {"mirostat", slot.sparams.mirostat},
        {"mirostat_tau", slot.sparams.mirostat_tau},
        {"mirostat_eta", slot.sparams.mirostat_eta},
        {"adaptive_target", slot.sparams.adaptive_target},
        {"adaptive_decay", slot.sparams.adaptive_decay},
        {"adaptive_updt_w_cur", slot.sparams.adaptive_updt_w_cur},
        {"penalize_nl", slot.sparams.penalize_nl},
        {"stop", slot.params.antiprompt},
        {"max_tokens", slot.params.n_predict}, // User configured n_predict
        {"n_keep", slot.params.n_keep},
        {"n_discard", slot.params.n_discard},
        {"ignore_eos", ignore_eos},
        {"stream", slot.params.stream},
        {"logit_bias", slot.sparams.logit_bias},
        {"n_probs", slot.sparams.n_probs},
        {"min_keep", slot.sparams.min_keep},
        {"grammar", slot.sparams.grammar},
        {"grammar_triggers", grammar_triggers},
        {"preserved_tokens", slot.sparams.preserved_tokens},
        {"chat_format", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)},
        {"reasoning_format", common_reasoning_format_name(slot.params.oaicompat_chat_syntax.reasoning_format)},
        {"reasoning_in_content", slot.params.oaicompat_chat_syntax.reasoning_in_content},
        {"thinking_forced_open", slot.params.oaicompat_chat_syntax.thinking_forced_open},
        {"samplers", samplers_sequence}
    };
}
|
|
|
|
// Convenience overload: report an error for a task, forwarding its ids.
void server_context::send_error(const server_task& task, const std::string& error, const enum error_type type) {
    send_error(task.id, task.id_multi, error, type);
}
|
|
|
|
// Convenience overload: report an error for a slot, forwarding its task ids.
void server_context::send_error(const server_slot& slot, const std::string& error, const enum error_type type) {
    send_error(slot.id_task, slot.id_multi, error, type);
}
|
|
|
|
void server_context::send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type ) {
|
|
LOG_ERROR("task error", {
|
|
{"id_multi", id_multi},
|
|
{"id_task", id_task},
|
|
{"error", error},
|
|
});
|
|
|
|
auto res = std::make_unique<server_task_result_error>();
|
|
res->id = id_task;
|
|
res->id_multi = id_multi;
|
|
res->stop = false;
|
|
res->error = true;
|
|
res->err_type = type;
|
|
res->err_msg = error;
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
// if multimodal is enabled, send an error and return false
|
|
bool server_context::check_no_mtmd(const int id_task) {
|
|
if (mctx) {
|
|
int id_multi = 0;
|
|
send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Stream one generated token (plus probabilities/timings when requested) to the
// client as a partial completion result.
void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
    // the slot may have been detached from its task (e.g. cancelled) — nothing to send
    if (slot.task == nullptr) {
        return;
    }
    auto res = std::make_unique<server_task_result_cmpl_partial>();

    res->final_result = false;
    res->id = slot.id_task;
    res->id_multi = slot.id_multi;
    res->index = slot.task->index;
    res->error = false;
    res->stop = false;
    res->stream = slot.params.stream;
    res->content = tkn.text_to_send;
    res->post_sampling_probs = slot.params.post_sampling_probs;

    // OpenAI-compat response metadata
    res->oaicompat = slot.params.oaicompat;
    res->oaicompat_model = slot.task->params.oaicompat_model;
    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
    res->oai_resp_id = slot.oai_resp_id;
    res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id;
    res->oai_resp_message_id = slot.oai_resp_message_id;
    res->oai_resp_fc_id = slot.oai_resp_fc_id;

    res->n_decoded = slot.n_decoded;
    res->n_prompt_tokens = slot.n_prompt_tokens;
    res->data = json{
        {"content", tkn.text_to_send},
        {"stop", false},
        {"id_slot", slot.id},
        {"multimodal", false}
    };
    // compute the incremental chat-message diffs produced by this token
    slot.update_chat_msg(res->oaicompat_msg_diffs);

    res->anthropic_has_reasoning = !slot.chat_msg.reasoning_content.empty();

    // snapshot the block-started flags BEFORE updating them below, so the
    // response reflects the state prior to this token
    res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started;
    res->anthropic_text_block_started = slot.anthropic_text_block_started;

    res->oai_resp_thinking_block_started = slot.oai_resp_thinking_block_started;
    res->oai_resp_text_block_started = slot.oai_resp_text_block_started;

    // advance the per-slot streaming state based on what this token's diffs contain
    for (const auto& diff : res->oaicompat_msg_diffs) {
        if (!diff.reasoning_content_delta.empty() && !slot.anthropic_thinking_block_started) {
            slot.anthropic_thinking_block_started = true;
        }
        if (!diff.content_delta.empty() && !slot.anthropic_text_block_started) {
            slot.anthropic_text_block_started = true;
        }
        if (!diff.reasoning_content_delta.empty() && !slot.oai_resp_thinking_block_started) {
            slot.oai_resp_thinking_block_started = true;
        }
        if (!diff.content_delta.empty() && !slot.oai_resp_text_block_started) {
            slot.oai_resp_text_block_started = true;
        }
        if (!diff.tool_call_delta.name.empty()) {
            // remember the id of the tool call currently being streamed
            slot.oai_resp_fc_id = diff.tool_call_delta.id;
        }
    }

    // populate res->probs_output
    if (slot.sparams.n_probs > 0) {
        res->probs_output = { tkn }; // copy the token probs
        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
    }

    if (slot.oaicompat) {
        res->data["oaicompat_token_ctr"] = slot.n_decoded;
        res->data["model"] = slot.oaicompat_model;
    }

    // populate timings if this is final response or timings_per_token is enabled
    if (slot.params.timings_per_token) {
        res->timings = slot.get_timings();
    }
    queue_results.send(std::move(res));
}
|
|
|
|
// Send the final completion result for a finished slot: full generated text,
// stop reason flags, timings, token counts and (optionally) probabilities.
void server_context::send_final_response(server_slot& slot) {
    auto res = std::make_unique<server_task_result_cmpl_final>();

    res->final_result = true;
    res->id = slot.id_task;
    res->id_multi = slot.id_multi;
    res->index = slot.task->index;
    res->error = false;
    res->stop = true; // to do: set value
    res->stream = slot.params.stream;
    res->include_usage = slot.params.include_usage;
    res->content = slot.generated_text;
    res->timings = slot.get_timings();
    res->post_sampling_probs = slot.params.post_sampling_probs;

    // OpenAI-compat response metadata; update_chat_msg yields the final parsed
    // chat message plus the last remaining diffs
    res->oaicompat = slot.params.oaicompat;
    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
    res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
    res->oai_resp_id = slot.oai_resp_id;
    res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id;
    res->oai_resp_message_id = slot.oai_resp_message_id;
    res->n_decoded = slot.n_decoded;
    res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started;
    res->anthropic_text_block_started = slot.anthropic_text_block_started;
    res->n_prompt_tokens = slot.n_prompt_tokens;
    res->oaicompat_model = slot.task->params.oaicompat_model;
    res->data = json{
        // in streaming mode the content has already been delivered incrementally
        {"content", !slot.params.stream ? slot.generated_text : ""},
        {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
        {"id_slot", slot.id},
        {"stop", true},
        {"model", params_base.model_alias},
        {"tokens_predicted", slot.n_decoded},
        {"tokens_evaluated", slot.n_prompt_tokens},
        {"generation_settings", get_formated_generation(slot)},
        {"prompt", slot.prompt},
        {"truncated", slot.truncated},
        {"stopped_eos", slot.stopped_eos},
        {"stopped_word", slot.stopped_word},
        {"stopped_limit", slot.stopped_limit},
        {"stopping_word", slot.stopping_word},
        {"tokens_cached", slot.n_past},
        {"timings", slot.get_formated_timings()},
        //{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
    };

    // populate res->probs_output
    if (slot.sparams.n_probs > 0) {
        res->probs_output = std::vector<completion_token_output>(
            slot.generated_token_probs.begin(),
            slot.generated_token_probs.end());
        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
    }

    if (slot.oaicompat) {
        res->data["oaicompat_token_ctr"] = slot.n_decoded;
        res->data["model"] = slot.oaicompat_model;
    }

    queue_results.send(std::move(res));
}
|
|
|
|
// Extract the embedding(s) for this slot from the decoded batch and send them
// as an embedding result. With pooling enabled, a single normalized
// per-sequence embedding is produced; otherwise one raw embedding per token.
void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
    auto res = std::make_unique<server_task_result_embd>();
    res->id = slot.task->id;
    res->index = slot.task->index;
    res->server_task_result::index = slot.task->index; // also set the base-class index field
    res->n_tokens = slot.prompt_tokens.size();
    res->oaicompat = slot.task->params.oaicompat;

    const int n_embd = llama_model_n_embd(model);

    // scratch buffer for the normalized embedding
    std::vector<float> embd_res(n_embd, 0.0f);

    for (int i = 0; i < batch.n_tokens; ++i) {
        // skip tokens without requested output or belonging to another slot's sequence
        if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
            continue;
        }

        const float* embd = nullptr;
        if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
            // no pooling: per-token embedding at this batch position
            embd = llama_get_embeddings_ith(ctx, i);
        }
        else {
            // pooling: one embedding for the whole sequence
            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        }

        if (embd == nullptr) {
            SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

            // keep the output aligned by emitting a zero vector for this token
            res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
            continue;
        }

        // normalize only when there is pooling
        if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
            common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
            res->embedding.push_back(embd_res);
            // pooled output is per-sequence, so one embedding is enough — stop scanning
            break;
        }

        // non-pooled: append the raw per-token embedding
        res->embedding.emplace_back(embd, embd + n_embd);
    }
    queue_results.send(std::move(res));
}
|
|
|
|
// Point the sampler at the server-side bias set for the slot's current
// allow-index; clear the pointer when no bias set remains.
void server_context::apply_server_biases(server_slot& slot) {
    const bool has_bias = slot.allow_idx < slot.allow_biasess.size();

    slot.ctx_sampling->server_biases = has_bias
        ? &slot.allow_biasess[slot.allow_idx]
        : nullptr;
}
|
|
|
|
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
|
|
server_task task;
|
|
task.id = id_task;
|
|
task.id_multi = id_multi;
|
|
task.id_target = 0;
|
|
task.data = std::move(data);
|
|
task.infill = infill;
|
|
task.embedding = embedding;
|
|
task.type = SERVER_TASK_TYPE_COMPLETION;
|
|
task.tokens = std::move(inputs);
|
|
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
|
// otherwise, it's a single-prompt task, we actually queue it
|
|
// if there's numbers in the prompt array it will be treated as an array of tokens
|
|
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
|
|
bool numbers = false;
|
|
for (const auto& e : task.data.at("prompt")) {
|
|
if (e.is_number()) {
|
|
numbers = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
|
|
// it will completely stall the server. I don't know where the bug for this is.
|
|
//
|
|
// if there are numbers, it needs to be treated like a single prompt,
|
|
// queue_tasks handles a mix of strings and numbers just fine.
|
|
if (numbers) {
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
else {
|
|
split_multiprompt_task(id_task, task);
|
|
}
|
|
}
|
|
else {
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
}
|
|
|
|
void server_context::request_cancel(int id_task) {
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_CANCEL;
|
|
task.id_target = id_task;
|
|
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
|
|
// Fan a multi-prompt completion request out into one subtask per prompt
// element, all tracked together under the multitask id `id_multi`.
void server_context::split_multiprompt_task(int id_multi, server_task& multiprompt_task) {
    const int prompt_count = multiprompt_task.data.at("prompt").size();
    if (prompt_count <= 1) {
        // callers should only reach this with >1 prompts
        send_error(multiprompt_task, "error while handling multiple prompts");
        return;
    }

    // generate all the ID for subtask
    std::vector<int> subtask_ids(prompt_count);
    for (int i = 0; i < prompt_count; i++) {
        subtask_ids[i] = queue_tasks.get_new_id();
    }

    // queue up the multitask so we can track its subtask progression
    queue_tasks.add_multitask(id_multi, subtask_ids);

    // add subtasks
    for (int i = 0; i < prompt_count; i++) {
        // each subtask gets a copy of the request data with only its own prompt
        json subtask_data = multiprompt_task.data;
        subtask_data["prompt"] = subtask_data.at("prompt")[i];

        // subtasks inherit everything else (infill mode, embedding mode, etc.)
        // NOTE(review): multiprompt_task.tokens is std::move()'d on every loop
        // iteration, so subtasks after the first receive a moved-from (likely
        // empty) token container — confirm whether only text prompts (which
        // re-tokenize from subtask_data) are expected on this path.
        request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding,
            std::move(multiprompt_task.tokens));
    }
}
|
|
|
|
|
|
|
|
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
|
|
if (checkpoints.size() == 0) {
|
|
return 0;
|
|
}
|
|
std::ofstream file(filename, std::ios::binary);
|
|
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
|
|
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
|
|
uint32_t version = LLAMA_STATE_SEQ_VERSION;
|
|
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
|
|
size_t count = checkpoints.size();
|
|
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
|
|
|
|
for (const auto & checkpoint : checkpoints) {
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
|
size_t data_len = checkpoint.data.size();
|
|
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
|
|
if (data_len > 0) {
|
|
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
|
}
|
|
}
|
|
size_t pos = file.tellp();
|
|
file.close();
|
|
return pos;
|
|
}
|
|
|
|
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
|
|
std::ifstream file(filename, std::ios::binary);
|
|
if (!file.is_open()) {
|
|
return 0;
|
|
}
|
|
checkpoints.clear();
|
|
// version checks
|
|
{
|
|
uint32_t magic;
|
|
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
|
|
uint32_t version;
|
|
file.read(reinterpret_cast<char *>(&version), sizeof(version));
|
|
|
|
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
|
LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
|
|
return 0;
|
|
}
|
|
}
|
|
// load the checkpoints
|
|
{
|
|
size_t count;
|
|
file.read(reinterpret_cast<char *>(&count), sizeof(count));
|
|
|
|
for (int i = 0; i < count; i++) {
|
|
server_prompt_checkpoint checkpoint;
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
|
|
|
size_t data_len;
|
|
file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
|
|
if (data_len > 0) {
|
|
checkpoint.data.resize(data_len);
|
|
file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
|
}
|
|
checkpoints.push_back(checkpoint);
|
|
}
|
|
}
|
|
size_t pos = file.tellg();
|
|
file.close();
|
|
return pos;
|
|
}
|
|
|
|
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
|
|
std::ofstream file(filename, std::ios::binary);
|
|
json token_json = tokens.to_json();
|
|
token_json["magic"] = LLAMA_SERVER_MAGIC;
|
|
token_json["version"] = LLAMA_SERVER_VERSION;
|
|
size_t pos = 0;
|
|
if (file.is_open()) {
|
|
file << token_json;
|
|
pos = file.tellp();
|
|
file.close();
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
|
|
std::ifstream file(filename, std::ios::binary);
|
|
if (!file.is_open()) {
|
|
return 0;
|
|
}
|
|
size_t pos = 0;
|
|
json token_json;
|
|
if (file.is_open()) {
|
|
file >> token_json;
|
|
pos = file.tellg();
|
|
file.close();
|
|
}
|
|
uint32_t magic = token_json.value<uint32_t>("magic", 0);
|
|
uint32_t version = token_json.value<uint32_t>("version", 0);
|
|
if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
|
|
LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
|
|
return 0;
|
|
}
|
|
tokens.from_json(token_json);
|
|
|
|
return pos;
|
|
}
|
|
|
|
// Central dispatcher for queued server tasks: assigns completion-style tasks
// to slots and handles control tasks (cancel, metrics, slot save/restore/erase,
// LoRA, control vectors). Runs on the task-queue thread.
void server_context::process_single_task(server_task&& task) {
    switch (task.type) {
        case SERVER_TASK_TYPE_COMPLETION:
        case SERVER_TASK_TYPE_INFILL:
        case SERVER_TASK_TYPE_EMBEDDING:
        case SERVER_TASK_TYPE_RERANK:
        {
            // optional explicit slot selection via "id_slot" in the request
            const int id_slot = json_value(task.data, "id_slot", -1);

            server_slot* slot;

            if (id_slot != -1) {
                slot = get_slot_by_id(id_slot);
            }
            else {
                slot = get_available_slot(task);
            }

            if (slot == nullptr) {
                // if no slot is available, we defer this task for processing later
                LOG_VERBOSE("no slot is available", { {"id_task", task.id} });
                queue_tasks.defer(std::move(task));
                break;
            }
            if (!slot->available()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
                queue_tasks.defer(std::move(task));
                break;
            }

            if (task.data.contains("system_prompt")) {
                std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
                system_prompt_set(sys_prompt);

                // a new system prompt invalidates every slot's cached context
                // (note: this loop variable shadows the outer `slot` pointer)
                for (server_slot& slot : slots) {
                    slot.n_past = 0;
                    slot.n_past_se = 0;
                }
            }

            slot->reset();

            // bind the task to the slot
            slot->id_task = task.id;
            slot->id_multi = task.id_multi;
            slot->infill = task.infill;
            slot->embedding = task.embedding;

            if (!launch_slot_with_task(*slot, task)) {
                LOG_ERROR("error while launching slot", task.data);
                break;
            }
        } break;
        case SERVER_TASK_TYPE_CANCEL:
        {
            // release slot linked with the task id
            for (auto& slot : slots) {
                if (slot.id_task == task.id_target) {
                    slot.release();
                    break;
                }
            }
        } break;
        case SERVER_TASK_TYPE_NEXT_RESPONSE:
        {
            // do nothing
        } break;
        case SERVER_TASK_TYPE_METRICS:
        {
            // collect per-slot state plus aggregate counters for /metrics and /slots
            json slots_data = json::array();

            int n_idle_slots = 0;
            int n_processing_slots = 0;

            for (server_slot& slot : slots) {
                json slot_data = get_formated_generation(slot);
                slot_data["id"] = slot.id;
                slot_data["id_task"] = slot.id_task;
                slot_data["state"] = slot.state;
                slot_data["prompt"] = slot.prompt;
                slot_data["next_token"] = {
                    {"has_next_token", slot.has_next_token},
                    {"n_remain", slot.n_remaining},
                    {"n_decoded", slot.n_decoded},
                    {"stopped_eos", slot.stopped_eos},
                    {"stopped_word", slot.stopped_word},
                    {"stopped_limit", slot.stopped_limit},
                    {"stopping_word", slot.stopping_word},
                };

                if (slot_data["state"] == SLOT_STATE_IDLE) {
                    n_idle_slots++;
                }
                else {
                    n_processing_slots++;
                }

                slots_data.push_back(slot_data);
            }
            LOG_INFO("slot data", {
                {"id_task", task.id},
                {"n_idle_slots", n_idle_slots},
                {"n_processing_slots", n_processing_slots}
            });

            LOG_VERBOSE("slot data", {
                {"id_task", task.id},
                {"n_idle_slots", n_idle_slots},
                {"n_processing_slots", n_processing_slots},
                {"slots", slots_data}
            });

            server_task_result res;
            res.id = task.id;
            res.id_multi = task.id_multi;
            res.stop = true;
            res.error = false;
            res.data = {
                { "idle", n_idle_slots },
                { "processing", n_processing_slots },
                { "deferred", queue_tasks.queue_tasks_deferred.size() },
                { "t_start", metrics.t_start},

                // lifetime totals
                { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
                { "t_tokens_generation_total", metrics.t_tokens_generation_total},
                { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
                { "t_prompt_processing_total", metrics.t_prompt_processing_total},

                // current-bucket counters (reset below when requested)
                { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
                { "t_prompt_processing", metrics.t_prompt_processing},
                { "n_tokens_predicted", metrics.n_tokens_predicted},
                { "t_tokens_generation", metrics.t_tokens_generation},

                { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
                { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},

                { "slots", slots_data },
            };

            if (json_value(task.data, "reset_bucket", false)) {
                metrics.reset_bucket();
            }
            queue_results.send(res);
        } break;
        case SERVER_TASK_TYPE_SLOT_SAVE:
        {
            int id_slot = task.data.at("id_slot");
            server_slot* slot = get_slot_by_id(id_slot);
            if (slot == nullptr) {
                send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
            if (!slot->available()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
                queue_tasks.defer(std::move(task));
                break;
            }

            const size_t token_count = slot->cache_tokens.size();
            const int64_t t_start = ggml_time_us();

            std::string filename = task.data.at("filename");
            std::string filepath = task.data.at("filepath");
            // three artifacts per save: token cache (JSON), recurrent-state
            // checkpoints, and the llama sequence state itself
            save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
            size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);

            const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);

            const int64_t t_end = ggml_time_us();
            const double t_save_ms = (t_end - t_start) / 1000.0;

            server_task_result result;
            result.id = task.id;
            result.stop = true;
            result.error = false;
            result.data = json{
                { "id_slot", id_slot },
                { "filename", filename },
                { "n_saved", token_count }, // tokens saved
                { "n_written", nwrite + saved }, // bytes written
                { "timings", {
                    { "save_ms", t_save_ms }
                } }
            };
            queue_results.send(result);
        } break;
        case SERVER_TASK_TYPE_SLOT_RESTORE:
        {
            int id_slot = task.data.at("id_slot");
            server_slot* slot = get_slot_by_id(id_slot);
            if (slot == nullptr) {
                send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
            if (!slot->available()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
                queue_tasks.defer(std::move(task));
                break;
            }
            const int64_t t_start = ggml_time_us();

            std::string filename = task.data.at("filename");
            std::string filepath = task.data.at("filepath");

            // provision a full-context buffer for the incoming sequence state
            slot->cache_tokens.resize(slot->n_ctx);
            size_t token_count = 0;
            size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
            if (nread == 0) {
                slot->cache_tokens.resize(0);
                send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
            // restore the companion artifacts written by SLOT_SAVE
            load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
            size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);

            const int64_t t_end = ggml_time_us();
            const double t_restore_ms = (t_end - t_start) / 1000.0;

            server_task_result result;
            result.id = task.id;
            result.stop = true;
            result.error = false;
            result.data = json{
                { "id_slot", id_slot },
                { "filename", filename },
                { "n_restored", token_count }, // tokens restored
                { "n_read", nread }, // bytes read
                { "timings", {
                    { "restore_ms", t_restore_ms }
                } }
            };
            queue_results.send(result);
        } break;
        case SERVER_TASK_TYPE_SLOT_ERASE:
        {
            int id_slot = task.data.at("id_slot");
            server_slot* slot = get_slot_by_id(id_slot);
            if (slot == nullptr) {
                send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
            if (!slot->available()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
                queue_tasks.defer(std::move(task));
                break;
            }
            // Erase token cache
            const size_t n_erased = slot->cache_tokens.size();
            // drop the slot's entire KV-cache sequence and all cached prompt state
            llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
            slot->cache_tokens.keep_first(0);
            //slot->cache_tokens.clear();
            slot->server_cached_prompt.checkpoints.clear();
            slot->server_cached_prompt.data.clear();
            server_task_result result;
            result.id = task.id;
            result.stop = true;
            result.error = false;
            result.data = json{
                { "id_slot", id_slot },
                { "n_erased", n_erased }
            };
            queue_results.send(result);
        } break;
        case SERVER_TASK_TYPE_SET_LORA:
        {
            // re-apply the currently configured LoRA adapters to the context
            llama_lora_adapters_apply(ctx, lora_adapters);
            server_task_result result;
            result.id = task.id;
            result.stop = true;
            result.error = false;
            result.data = json{ { "success", true } };
            queue_results.send(result);
        } break;
        case SERVER_TASK_TYPE_LOAD_CONTROL_VECTOR:
        {
            // Load control vector from file
            std::string path = task.data.at("path");
            float scale = task.data.value("scale", 1.0f);
            int32_t layer_start = task.data.value("layer_start", 1);
            int32_t layer_end = task.data.value("layer_end", llama_n_layer(model));

            // Check if already loaded
            int cv_id = -1;
            for (size_t i = 0; i < control_vectors.size(); i++) {
                if (control_vectors[i].path == path) {
                    // same file: just update its scale/layer range in place
                    control_vectors[i].scale = scale;
                    control_vectors[i].layer_start = layer_start;
                    control_vectors[i].layer_end = layer_end;
                    cv_id = i;
                    break;
                }
            }

            if (cv_id == -1) {
                control_vector_container new_cv;
                new_cv.path = path;
                new_cv.scale = scale;
                new_cv.layer_start = layer_start;
                new_cv.layer_end = layer_end;
                new_cv.applied = false;

                // Load the control vector data
                llama_control_vector_load_info load_info;
                load_info.fname = path;
                load_info.strength = 1.0f; // Don't pre-scale here, we'll scale when applying

                std::vector<llama_control_vector_load_info> load_infos = { load_info };
                new_cv.data = llama_control_vector_load(load_infos);

                if (new_cv.data.n_embd == -1) {
                    server_task_result result;
                    result.id = task.id;
                    result.error = true;
                    result.data = json{{ "success", false }, { "error", "Failed to load control vector from " + path }};
                    queue_results.send(result);
                    break;
                }

                // Validate dimension to prevent heap corruption
                if (new_cv.data.n_embd != llama_model_n_embd(model)) {
                    server_task_result result;
                    result.id = task.id;
                    result.error = true;
                    result.data = json{{ "success", false },
                        { "error", "Vector dimension mismatch" }};
                    queue_results.send(result);
                    break;
                }

                control_vectors.push_back(new_cv);

                cv_id = control_vectors.size() - 1;
            }

            // Auto-apply control vectors after loading
            if (!apply_control_vectors_internal()) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
                queue_results.send(result);
                break;
            }

            server_task_result result;
            result.id = task.id;
            result.error = false;
            result.data = json{{ "success", true }, { "id", cv_id }};
            queue_results.send(result);
        } break;
        case SERVER_TASK_TYPE_UNLOAD_CONTROL_VECTOR:
        {
            // Validate that "id" field exists and is a number
            if (!task.data.contains("id") || task.data["id"].is_null() || !task.data["id"].is_number()) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false }, { "error", "Missing or invalid 'id' field" }};
                queue_results.send(result);
                break;
            }

            int id = task.data.at("id");

            if (id < 0 || id >= (int)control_vectors.size()) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false }, { "error", "Invalid control vector ID" }};
                queue_results.send(result);
                break;
            }

            // Remove the control vector from the list
            control_vectors.erase(control_vectors.begin() + id);

            // Reapply remaining control vectors
            if (!apply_control_vectors_internal()) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
                queue_results.send(result);
                break;
            }

            server_task_result result;
            result.id = task.id;
            result.error = false;
            result.data = json{{ "success", true }};
            queue_results.send(result);
        } break;
        case SERVER_TASK_TYPE_SET_CONTROL_VECTOR:
        {
            // re-aggregate and apply all loaded control vectors with current scales
            if (!apply_control_vectors_internal()) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
                queue_results.send(result);
                break;
            }

            server_task_result result;
            result.id = task.id;
            result.error = false;
            result.data = json{{ "success", true }};
            queue_results.send(result);
        } break;
    }
}
|
|
|
|
bool server_context::apply_control_vectors_internal() {
|
|
llama_control_vector_data combined_cv = { -1, {} };
|
|
|
|
// Check if we have anything to apply
|
|
bool any_active = false;
|
|
for (const auto& cv : control_vectors) {
|
|
if (cv.scale != 0.0f) {
|
|
any_active = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!any_active) {
|
|
// Clear control vectors if nothing is active
|
|
llama_control_vector_apply(ctx, nullptr, 0, 0, 0, 0);
|
|
return true;
|
|
}
|
|
|
|
// Aggregate control vectors with scaling
|
|
for (auto& cv : control_vectors) {
|
|
if (cv.scale == 0.0f) {
|
|
cv.applied = false;
|
|
continue;
|
|
}
|
|
|
|
if (combined_cv.n_embd == -1) {
|
|
combined_cv.n_embd = cv.data.n_embd;
|
|
combined_cv.data.resize(cv.data.data.size(), 0.0f);
|
|
}
|
|
|
|
for (size_t i = 0; i < cv.data.data.size(); i++) {
|
|
combined_cv.data[i] += cv.data.data[i] * cv.scale;
|
|
}
|
|
cv.applied = true;
|
|
}
|
|
|
|
// Apply combined control vector
|
|
if (combined_cv.n_embd != -1 && !combined_cv.data.empty()) {
|
|
int32_t min_layer_start = INT32_MAX;
|
|
int32_t max_layer_end = 0;
|
|
|
|
for (const auto& cv : control_vectors) {
|
|
if (cv.scale != 0.0f) {
|
|
min_layer_start = std::min(min_layer_start, cv.layer_start);
|
|
max_layer_end = std::max(max_layer_end, cv.layer_end);
|
|
}
|
|
}
|
|
|
|
int err = llama_control_vector_apply(ctx,
|
|
combined_cv.data.data(),
|
|
combined_cv.data.size(),
|
|
combined_cv.n_embd,
|
|
min_layer_start,
|
|
max_layer_end);
|
|
return (err == 0);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void server_context::on_finish_multitask(const server_task_multi& multitask) {
|
|
// all subtasks done == multitask is done
|
|
server_task_result result;
|
|
result.id = multitask.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
|
|
// collect json results into one json result
|
|
std::vector<json> result_jsons;
|
|
for (const auto& subres : multitask.results) {
|
|
result_jsons.push_back(subres.data);
|
|
result.error = result.error && subres.error;
|
|
}
|
|
result.data = json{
|
|
{ "results", result_jsons }
|
|
};
|
|
|
|
queue_results.send(result);
|
|
}
|
|
|
|
void server_context::print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1, size_t start2, size_t length) {
|
|
if (cache.size() > start2) {
|
|
LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str());
|
|
}
|
|
if (prompt.size() > start1) {
|
|
LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str());
|
|
}
|
|
|
|
}
|
|
|
|
// Remove a span of tokens from the middle of a slot's sequence:
//  - keep the first n_keep cached tokens,
//  - drop the next n_discard tokens from the KV cache,
//  - shift the remaining KV entries back so positions stay contiguous,
//  - mirror the same shift in the speculative/MTP context (if any) and in the
//    token cache (when prompt caching is enabled).
void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
    // translate token counts into KV positions (multimodal chunks can occupy
    // multiple positions, so counts and positions are not interchangeable)
    auto kv_keep    = slot.cache_tokens.pos_next(n_keep);
    auto kv_discard = slot.cache_tokens.pos_next(n_keep + n_discard) - kv_keep;
    auto kv_past    = slot.cache_tokens.pos_next(slot.n_past);

    llama_kv_cache_seq_rm (ctx, slot.id, kv_keep, kv_keep + kv_discard);
    llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);

    if (slot.has_mtp && slot.spec) {
        // keep the draft/MTP KV cache in sync with the main context
        common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
    }
    if (slot.params.cache_prompt) {
        slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
    }
}
|
|
|
|
|
|
inline static bool tokens_support_context_shift(const server_tokens & tokens, int32_t n_keep,
|
|
int32_t n_discard) {
|
|
bool can_shift = !tokens.has_mtmd;
|
|
if (tokens.has_mtmd) {
|
|
can_shift = true;
|
|
if (n_keep > 0 && n_keep<= tokens.n_tokens()) {
|
|
can_shift = tokens[n_keep - 1] != LLAMA_TOKEN_NULL;
|
|
}
|
|
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
|
|
can_shift = can_shift && tokens[n_discard + n_keep - 1] != LLAMA_TOKEN_NULL;
|
|
}
|
|
}
|
|
return can_shift;
|
|
}
|
|
|
|
// Nudge the keep/discard boundaries of a context shift so that neither cut
// point lands inside a multimodal (image) chunk, whose tokens are marked with
// LLAMA_TOKEN_NULL. The keep boundary is moved DOWN (keep fewer tokens) and the
// discard boundary is moved UP (discard more tokens) until both land on real
// text tokens. No-op for text-only token lists.
inline static void adjust_n_to_support_context_shift(const server_tokens & tokens, int32_t & n_keep,
        int32_t & n_discard) {
    if (!tokens.has_mtmd) {
        return;
    }
    if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
        // shrink n_keep until the last kept token is a real text token
        while (tokens[n_keep - 1] == LLAMA_TOKEN_NULL) {
            n_keep--;
            // break BEFORE the next loop test would index out of range
            if (n_keep<1 || n_keep>tokens.size()) {
                break;
            }
        }
    }
    if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
        // grow n_discard until the last discarded token is a real text token
        while (tokens[n_discard + n_keep - 1] == LLAMA_TOKEN_NULL) {
            n_discard++;
            if (n_discard + n_keep<1 || n_discard + n_keep>tokens.size()) {
                break;
            }
        }
    }

}
|
|
|
|
|
|
// convert keep first few and discard next tokens in a to b
|
|
// Translate keep/discard counts expressed in terms of token list `a` (the
// prompt) into the corresponding counts for token list `b` (the cache), via
// their common prefixes. This compensates for mistokenization: the same text
// may tokenize to slightly different token counts in prompt vs cache.
//
// Outputs:
//   n_kept      - number of cache tokens corresponding to the first n_keep prompt tokens
//   n_discarded - number of cache tokens corresponding to the next n_discard prompt tokens
// On underflow (not enough common tokens) the outputs fall back to the
// untranslated n_keep/n_discard values.
void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
        int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) {

    // common_prefix.first  = matched length in `a`, .second = matched length in `b`
    common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact);
    common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact);
    // only if there is enough common token
    // offsets are <= 0: how far the actual match fell short of the request
    int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep);
    int32_t keep_offset = ctx_keep_prefix.first - n_keep;
    n_kept = ctx_keep_prefix.second - keep_offset;
    n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset;
    if (n_kept < 0) {
        // not enough overlap - assume a 1:1 token correspondence
        n_kept = n_keep;
    }
    if (n_discarded < 0) {
        n_discarded = n_discard;
    }
}
|
|
|
|
// Truncate an over-long prompt BEFORE processing it, discarding tokens after
// the first n_keep until the remainder fits in the slot's context. Matching
// tokens are also dropped from the cache and KV cache so that the already
// computed prefix can still be reused instead of reprocessed.
void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
    int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
    const int n_left = slot.n_ctx - n_keep;
    // default discard granularity: half the free context
    int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
    // move cut points off multimodal chunks, then verify the shift is possible
    adjust_n_to_support_context_shift(slot.prompt_tokens, n_keep, n_discard);
    if (n_discard<=0 || !tokens_support_context_shift(slot.prompt_tokens, n_keep, n_discard)) {
        return;
    }
    int n_discard_prompt = 0;
    // we still need to truncate input since we have not discarded enough tokens
    // (n_discarded_prompt accumulates across repeated shifts of the same slot)
    while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) {
        slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
        n_discard_prompt = n_discard_prompt + n_discard;
    }

    // Handle mistokenization between prompt and cache during context shift
    //
    int32_t n_discard_cache = n_discard_prompt;
    int32_t n_kept = n_keep;
    // first drop the tokens already discarded by previous shifts
    slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt);
    if (n_discard_prompt > 0) {
        // translate prompt-token counts into cache-token counts
        context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
            n_discard, n_kept, n_discard_cache, exact);
    }

    // never discard more cache tokens than exist past the kept prefix
    int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0);
    n_discard_cache = std::min(n_discard_cache, n_discard_cache_max);
    // discard matching tokens from cache and kv cache to avoid reprocessing the prompt
    if (n_discard_cache > 0) {
        discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
    }
    // discard extra tokens from prompts
    slot.n_kept_prompt = n_keep;
    slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt);
    slot.n_prompt_tokens = slot.prompt_tokens.size();
}
|
|
|
|
void server_context::release_slots()
|
|
{
|
|
for (auto& slot : slots) {
|
|
if (slot.command == SLOT_COMMAND_RELEASE) {
|
|
slot.state = SLOT_STATE_IDLE;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.t_last_used = ggml_time_us();
|
|
|
|
LOG_INFO("slot released", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", n_ctx},
|
|
{"n_past", slot.n_past},
|
|
{"n_system_tokens", system_tokens.size()},
|
|
{"n_cache_tokens", slot.cache_tokens.size()},
|
|
{"truncated", slot.truncated}
|
|
});
|
|
|
|
queue_tasks.notify_slot_changed();
|
|
}
|
|
}
|
|
}
|
|
|
|
bool server_context::slots_idle(){
|
|
bool all_idle = true;
|
|
for (auto& slot : slots) {
|
|
if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
|
|
all_idle = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (all_idle) {
|
|
LOG_INFO("all slots are idle", {});
|
|
if (system_prompt.empty() && clean_kv_cache) {
|
|
kv_cache_clear();
|
|
}
|
|
all_idle = true;
|
|
}
|
|
return all_idle;
|
|
}
|
|
|
|
// Perform a context shift for every generating slot that is about to run out
// of context: keep the first n_keep tokens, drop the next n_discard, and slide
// the rest back. Only applies when group attention (Self-Extend) is disabled
// (ga_n == 1); with ga_n != 1 the context is extended elsewhere instead.
void server_context::context_shift() {
    for (server_slot& slot : slots) {
        if (slot.ga_n == 1) {
            // trigger one token before the context is actually full
            if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
                if (!params_base.ctx_shift) {
                    // this check is redundant (for good)
                    // we should never get here, because generation should already stopped in process_token()
                    slot.print_timings();
                    slot.release();
                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
                    continue;
                }
                // Shift context
                // n_keep < 0 means "keep the whole prompt"
                int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
                if (add_bos_token) {
                    n_keep += 1;
                }
                // always leave at least 4 positions free to discard into
                n_keep = std::min(slot.n_ctx - 4, n_keep);

                const int32_t n_left = (int)system_tokens.size() + slot.n_past - n_keep;
                int32_t n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
                int32_t n_kept;
                int32_t n_discard_cache;
                // move cut points off multimodal chunks before attempting the shift
                adjust_n_to_support_context_shift(slot.cache_tokens, n_keep, n_discard);
                if (n_discard > 0 && tokens_support_context_shift(slot.cache_tokens, n_keep, n_discard)) {
                    // translate prompt-side keep/discard counts to cache-side counts
                    context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
                        n_discard, n_kept, n_discard_cache);
                    LOG_INFO("slot context shift", {
                        {"id_slot", slot.id},
                        {"id_task", slot.id_task},
                        {"n_keep", n_keep},
                        {"n_left", n_left},
                        {"n_discard", n_discard},
                        {"n_ctx", n_ctx},
                        {"n_past", slot.n_past},
                        {"n_system_tokens", system_tokens.size()},
                        {"n_cache_tokens", slot.cache_tokens.size()}
                    });
                    slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
                    slot.n_kept_prompt = n_keep;
                    discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
                    slot.n_past -= n_discard_cache;
                    slot.truncated = true;
                }

            }
        }
    }
}
|
|
|
|
// For every active slot, append the last sampled token (and, when speculative
// decoding is enabled, a batch of draft tokens) to the global decode batch and
// to the slot's token cache. Draft tokens are verified later against the main
// model's logits.
void server_context::add_sampled_tokens() {
    for (auto& slot : slots) {
        slot.released = false;
        if (slot.state == SLOT_STATE_IDLE) {
            continue;
        }

        // generate draft tokens in speculative decoding mode
        // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
        // perform the speculative drafting for all sequences at the same time in a single batch
        const int n_draft_max_pre = slot.get_n_draft_max();
        if (n_draft_max_pre > 0) {
            if (mctx) {
                // we should never reach this, as speculative is automatically disabled if mmproj is loaded
                GGML_ABORT("not supported by multimodal");
            }

            const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();

            auto & params_spec = slot.params.speculative;

            if (slot.has_mtp) {
                // MTP drafting needs the hidden state of the last verified token;
                // feed it to the MTP context (or the main context as fallback)
                llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
                llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
                if (!slot.mtp_hidden_state.empty()) {
                    const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                    const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
                    // pass only the most recent hidden-state row
                    llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
                } else {
                    LOG_ERROR("MTP hidden state is empty during speculation", {});
                    // fallback: recover the hidden state from the last embeddings row
                    const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
                    if (emb_neg1) {
                        const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                        slot.mtp_hidden_state.resize(n_embd);
                        memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
                        llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
                    }
                }
            }

            llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

            // re-query: autotune may have changed the limit during drafting
            const int n_draft_max = slot.get_n_draft_max();

            if (draft.size() > (size_t)n_draft_max) {
                if (slot.params.speculative.autotune) {
                    // expected near end-of-response when autotune shrinks n_max
                    SLT_DBG(slot, "draft size %d exceeds max %d, truncating\n", (int)draft.size(), n_draft_max);
                } else {
                    SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int)draft.size(), n_draft_max);
                }
                draft.resize(n_draft_max);
            }

            // add the sampled token to the batch
            slot.i_batch_dft.push_back(batch.n_tokens);
            common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
            slot.cache_tokens.push_back(slot.sampled);

            if (slot.params.speculative.n_min > (int)draft.size()) {
                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
                // fallback to normal decoding
                slot.i_batch = slot.i_batch_dft[0];
                slot.drafted.clear();
                slot.i_batch_dft.clear();
            }
            else {
                // keep track of total number of drafted tokens tested
                slot.n_draft_total += draft.size();

                // add all drafted tokens to the batch
                for (size_t i = 0; i < draft.size(); i++) {
                    slot.i_batch_dft.push_back(batch.n_tokens);
                    common_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
                    slot.cache_tokens.push_back(draft[i]);
                }
                slot.drafted = std::move(draft);
            }
        }
        else {
            // no speculative decoding
            slot.i_batch = batch.n_tokens;

            common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);

            slot.cache_tokens.push_back(slot.sampled);

            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
                (int)slot.n_ctx, (int)slot.cache_tokens.size(), (int)slot.truncated);
        }
        // cache length is the source of truth for n_past after appending
        slot.n_past = slot.cache_tokens.n_tokens();
    }
}
|
|
|
|
// Create a recurrent/partial-state checkpoint for the slot whenever the
// sequence has advanced at least ctx_checkpoints_interval positions past the
// last successful checkpoint. No-op unless periodic checkpointing is enabled.
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
    if (!params_base.do_checkpoint || params_base.ctx_checkpoints_interval <= 0) {
        return;
    }

    const auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);

    // only checkpoint once a full interval has elapsed since the last one
    if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
        if (create_checkpoint(slot)) {
            slot.checkpoint_pos = pos;
        }
    }
}
|
|
|
|
// When prompt reuse is not possible from the live KV cache (e.g. SWA or
// recurrent memory dropped old positions), try to restore the most recent
// context checkpoint that still covers the reuse point; otherwise force a full
// prompt re-process. Also prunes checkpoints invalidated by the current reuse
// point.
void server_context::apply_checkpoint(server_slot & slot) {
    llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
    // the newest KV position we are allowed to depend on
    const auto pos_min_thold = std::max(0, pos_next - 1);
    if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
        int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);

        // pos_min > threshold means the KV cache no longer holds the prefix we need
        if (pos_min > pos_min_thold) {
            SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);

            // search for a context checkpoint (newest first)
            const auto it = std::find_if(
                slot.server_cached_prompt.checkpoints.rbegin(),
                slot.server_cached_prompt.checkpoints.rend(),
                [&](const auto & cur) {
                    // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
                    return cur.pos_min < pos_min_thold;
                }
            );

            bool do_reset = it == slot.server_cached_prompt.checkpoints.rend();

            if (!do_reset) {
                // restore the context checkpoint
                const int64_t t_start = ggml_time_us();
                const size_t checkpoint_size = it->data.size();
                const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

                if (n != checkpoint_size) {
                    SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
                    do_reset = true;
                    //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                } else {
                    // rewind n_past (cache side) and n_past_prompt (prompt side)
                    // to the positions covered by the restored checkpoint
                    slot.n_past = std::min(slot.n_past, std::max(it->pos_min+1, it->pos_max));
                    slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
                    slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt+1, it->pos_max_prompt));
                    slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
                    SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
                }
            }

            if (do_reset) {
                SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                slot.n_past = 0;
                slot.n_past_prompt = 0;
            }
        }
    }

    {
        // erase any checkpoints with pos_min > pos_min_thold
        // (they were taken past the point we are about to overwrite)
        for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
            const auto & cur = *it;
            if (cur.pos_min > pos_min_thold) {
                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
                it = slot.server_cached_prompt.checkpoints.erase(it);
            } else {
                ++it;
            }
        }
    }
}
|
|
|
|
// Snapshot the slot's sequence state (partial/recurrent-only flags) into a new
// checkpoint, evicting the oldest checkpoints to respect ctx_checkpoints_n.
// Returns true if a checkpoint was actually created. Skipped right after an
// image chunk, for tiny sequences, and when a newer checkpoint already exists.
bool server_context::create_checkpoint(server_slot & slot) {
    bool do_checkpoint = !slot.image_just_processed;
    int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
    const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);

    // no need for empty or small checkpoints
    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);

    // no need to create checkpoints that are too close together
    do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);

    if (do_checkpoint) {
        const int64_t t_start = ggml_time_us();
        while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
            // make room for the new checkpoint, if needed
            const auto & cur = slot.server_cached_prompt.checkpoints.front();

            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);

            slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
        }

        const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        // record both cache-relative and prompt-relative positions so a later
        // restore can rewind n_past and n_past_prompt consistently
        auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
            /*.pos_min = */ pos_min,
            /*.pos_max = */ pos_max,
            /*.pos_min_prompt = */ pos_min + slot.n_past_offset,
            /*.pos_max_prompt = */ pos_max + slot.n_past_offset ,
            /*.data = */ std::vector<uint8_t>(checkpoint_size),
        });

        llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
            (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
            (ggml_time_us() - t_start) / 1000.0);
    }
    return do_checkpoint;
}
|
|
|
|
// Fill the global decode batch with pending prompt tokens from idle slots that
// have a LOAD_PROMPT command. Handles (in order): infill tokenization, empty
// prompts, oversize embedding prompts, prompt-side context shift, prompt-cache
// reuse (with mistokenization handling), checkpoint restore, KV-cache trimming,
// multimodal image chunks, and finally streaming text tokens into the batch up
// to n_batch. batch_type (-1 = unset, 0 = completion, 1 = embedding) enforces
// that a single batch never mixes embedding and completion slots.
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
    if (params_base.cont_batching || batch.n_tokens == 0) {
        for (auto& slot : slots) {
            // this slot still has a prompt to be processed
            if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
                auto& prompt_tokens = slot.prompt_tokens;

                // we haven't tokenized the prompt yet - do it now:
                if (prompt_tokens.empty() || slot.n_prompt_tokens == 0) {
                    LOG_VERBOSE("tokenizing prompt", {
                        {"id_slot", slot.id},
                        {"id_task", slot.id_task}
                    });

                    slot.t_start_process_prompt = ggml_time_us();
                    slot.t_start_generation = 0;

                    if (slot.infill) {
                        const bool add_bos = llama_should_add_bos_token(model);
                        bool suff_rm_leading_spc = true;
                        // NOTE(review): this mutates the shared params_base.input_suffix
                        // in place - looks intentional (one-time normalization) but
                        // worth confirming for concurrent slots
                        if (params_base.input_suffix.find_first_of(' ') == 0 && params_base.input_suffix.size() > 1) {
                            params_base.input_suffix.erase(0, 1);
                            suff_rm_leading_spc = false;
                        }

                        auto prefix_tokens = tokenize(slot.params.input_prefix, false);
                        auto suffix_tokens = tokenize(slot.params.input_suffix, false);

                        const int space_token = 29871; // TODO: this should not be hardcoded
                        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
                            suffix_tokens.erase(suffix_tokens.begin());
                        }

                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                        suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));

                        // SPM infill swaps the order: suffix first, then prefix
                        auto embd_inp = params_base.spm_infill ? suffix_tokens : prefix_tokens;
                        auto embd_end = params_base.spm_infill ? prefix_tokens : suffix_tokens;
                        if (add_bos) {
                            embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                        }
                        embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());

                        const llama_token middle_token = llama_token_middle(model);
                        if (middle_token >= 0) {
                            embd_inp.push_back(middle_token);
                        }

                        prompt_tokens = server_tokens(embd_inp, false);
                    }
                    else {
                        // non-infill prompts arrive already tokenized
                        // prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                    }

                    slot.n_past = 0;
                    slot.n_prompt_tokens = prompt_tokens.size();

                    LOG_VERBOSE("prompt tokenized", {
                        {"id_slot", slot.id},
                        {"id_task", slot.id_task},
                        {"n_ctx", slot.n_ctx},
                        {"n_keep", slot.params.n_keep},
                        {"n_prompt_tokens", slot.n_prompt_tokens},
                        {"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
                    });

                    // empty prompt passed -> release the slot and send empty response
                    if (prompt_tokens.empty()) {
                        LOG_INFO("empty prompt - releasing slot", {
                            {"id_slot", slot.id},
                            {"id_task", slot.id_task}
                        });

                        slot.state = SLOT_STATE_PROCESSING;
                        slot.command = SLOT_COMMAND_NONE;
                        send_final_response(slot);
                        slot.release();
                        slot.print_timings();
                        continue;
                    }

                    if (slot.embedding) {
                        // this prompt is too large to process - discard it
                        // (embeddings require the full prompt in one ubatch)
                        if (slot.n_prompt_tokens > n_ubatch) {
                            slot.state = SLOT_STATE_PROCESSING;
                            slot.command = SLOT_COMMAND_NONE;
                            slot.release();
                            send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                            continue;
                        }
                    }
                    else {
                        // if input prompt is too big, truncate it (if group attention self-extend is disabled)
                        // context shift for prompt processing
                        if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
                            if (!params_base.ctx_shift) {
                                send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
                                slot.release();
                                continue;
                            }
                            context_shift_prompt(ctx, slot);
                            slot.truncated = true;
                            LOG_VERBOSE("input truncated", {
                                {"id_slot", slot.id},
                                {"id_task", slot.id_task},
                                {"n_ctx", slot.n_ctx},
                                {"n_keep", slot.params.n_keep},
                                {"n_left", slot.n_ctx - slot.params.n_keep},
                                {"n_prompt_tokens", slot.n_prompt_tokens},
                                {"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
                            });

                            GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);

#ifndef NDEBUG
                            // debug
                            common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
                            int32_t back = 1;
                            if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first + 20
                                && prefix.second >= back && prefix.first >= back) {
                                LLAMA_LOG_INFO("After context shift :\n");
                                print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50);
                            }
#endif
                        }
                        else {
                            slot.n_discarded_prompt = 0;
                        }
                        common_sampler_reset(slot.ctx_sampling);

                        if (!slot.params.cache_prompt) {
                            slot.n_past_se = 0;
                            slot.ga_i = 0;
                        }
                        else {
                            GGML_ASSERT(slot.ga_n == 1);

                            // reuse any previously computed tokens that are common with the new prompt
                            common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match
                            common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
                            auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match
                            LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t)slot.cache_tokens.size(), (int32_t)n_past0, (int32_t)prefix.first, (int32_t)prefix.second, (int32_t)prefix_nonexact.first, (int32_t)prefix_nonexact.second);
                            // prefer the relaxed (whitespace-tolerant) match when it
                            // extends the reusable prefix significantly
                            int32_t size_threshold = 20;
                            if (prefix.first + size_threshold < prefix_nonexact.first) {
                                LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n");
                                prefix = prefix_nonexact;
                            }
                            slot.n_past = prefix.first;
                            slot.n_past_prompt = prefix.second;
                            // offset between prompt-token index and cache-token index
                            slot.n_past_offset = slot.n_past_prompt - slot.n_past;

                            if (slot.n_past != slot.n_past_prompt) {
                                LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
                            }
                            if ((slot.n_past + size_threshold < slot.cache_tokens.size()))
                            {
                                LLAMA_LOG_WARN("Common part does not match fully\n");
                                int32_t back = 4;
                                if (prefix.second >= back && prefix.first >= back) {
                                    print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30);
                                }
                            }

                            // push the prompt into the sampling context (do not apply grammar)
                            for (int i = 0; i < slot.n_past; ++i) {
                                common_sampler_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
                            }
                        }
                    }
                    apply_checkpoint(slot);
                    if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) {
                        // we have to evaluate at least 1 token to generate logits.
                        LOG_INFO("we have to evaluate at least 1 token to generate logits", {
                            { "id_slot", slot.id },
                            { "id_task", slot.id_task }
                        });

                        slot.n_past_prompt--;
                        slot.n_past--;
                        if (slot.ga_i > 0) {
                            slot.n_past_se--;
                        }
                    }

                    slot.n_prompt_tokens_processed = 0;
                }

                if (slot.embedding) {
                    // cannot fit the prompt in the current batch - will try next iter
                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                        continue;
                    }
                }

                // check that we are in the right batch_type, if not defer the slot
                bool slot_type = slot.embedding ? 1 : 0;
                if (batch_type == -1) {
                    batch_type = slot_type;
                }
                else if (batch_type != slot_type) {
                    continue;
                }

                // keep only the common part
                // remove the non-common part from the cache
                if (slot.n_past < 0)
                {
                    slot.n_past = 0;
                }
                slot.cache_tokens.keep_first(slot.n_past);
                int p0 = (int)system_tokens.size() + slot.n_past;
                // position-based p0 (multimodal chunks can span multiple positions)
                p0 = system_tokens.size() + slot.cache_tokens.pos_next();
                if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) {
                    // could not partially delete (likely using a non-Transformer model)
                    llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);

                    p0 = (int)system_tokens.size();
                    if (p0 != 0) {
                        // copy over the system prompt when there is one
                        llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1);
                    }

                    // there is no common part left (except for the system prompt)
                    slot.n_past = 0;
                    slot.n_past_se = 0;
                    slot.ga_i = 0;
                    // TODO: is the system prompt ever in the sampling context?
                    common_sampler_reset(slot.ctx_sampling);
                }

                LOG_INFO("kv cache rm [p0, end)", {
                    { "id_slot", slot.id },
                    { "id_task", slot.id_task },
                    { "p0", p0 }
                });

                // check if we should process the image
                if (slot.n_past_prompt < slot.n_prompt_tokens
                    && slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) {
                    // process the image
                    size_t n_tokens_out = 0;
                    llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
                    int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out);
                    if (res != 0) {
                        LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
                        slot.release();
                        send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
                        continue;
                    }

                    // add the image chunk to cache
                    {
                        const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt);
                        slot.cache_tokens.push_back(chunk.get()); // copy
                    }

                    slot.n_past += n_tokens_out;
                    slot.n_past_prompt += n_tokens_out;
                    slot.n_prompt_tokens_processed += n_tokens_out;
                    slot.image_just_processed = true; // do not checkpoint right after an image chunk
                }

                int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

                int32_t ga_i = slot.ga_i;
                int32_t ga_n = slot.ga_n;
                int32_t ga_w = slot.ga_w;

                // add prompt tokens for processing in the current batch
                // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
                while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                    // get next token to process
                    llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt];
                    if (cur_tok == LLAMA_TOKEN_NULL) {
                        break; // end of text chunk
                    }
                    if (slot.ga_n != 1) {
                        // remap positions for group attention (Self-Extend)
                        while (slot_npast >= ga_i + ga_w) {
                            const int bd = (ga_w / ga_n) * (ga_n - 1);
                            slot_npast -= bd;
                            ga_i += ga_w / ga_n;
                        }
                    }

                    int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
                    common_batch_add(batch, cur_tok, p0, { slot.id }, slot.need_embd());

                    slot.cache_tokens.push_back(cur_tok);

                    slot.n_prompt_tokens_processed++;
                    slot_npast++;
                    slot.n_past_prompt++;
                    slot.n_past++;
                    slot.image_just_processed = false;
                    // stop early to allow a checkpoint shortly before the prompt end
                    if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
                        slot.do_checkpoint = true;
                        break;
                    }

                }
                LOG_VERBOSE("prompt processing progress", {
                    {"id_slot", slot.id},
                    {"n_past", slot.n_past},
                    {"n_ctx", n_ctx},
                    {"n_tokens", batch.n_tokens},
                    {"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
                });

                // entire prompt has been processed - start decoding new tokens
                if (slot.n_past_prompt == slot.n_prompt_tokens) {
                    slot.state = SLOT_STATE_PROCESSING;
                    slot.command = SLOT_COMMAND_NONE;
                    GGML_ASSERT(batch.n_tokens > 0);
                    GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
                    // re-seed the sampler with the FULL prompt (grammar not applied)
                    common_sampler_reset(slot.ctx_sampling);
                    for (int i = 0; i < slot.n_prompt_tokens; ++i) {
                        llama_token id = slot.prompt_tokens[i];
                        if (id != LLAMA_TOKEN_NULL) {
                            common_sampler_accept(slot.ctx_sampling, ctx, id, false);
                        }
                    }

                    // extract the logits only for the last token
                    batch.logits[batch.n_tokens - 1] = true;

                    slot.n_decoded = 0;
                    slot.i_batch = batch.n_tokens - 1;

                    LOG_VERBOSE("prompt done", {
                        {"id_slot", slot.id},
                        {"n_past", slot.n_past},
                        {"n_ctx", n_ctx},
                        {"n_tokens", batch.n_tokens},
                    });
                }
            }

            if (batch.n_tokens >= n_batch) {
                break;
            }
        }
    }
}
|
|
|
|
// Advance the Self-Extend (grouped-attention) window for every slot that uses it
// (slot.ga_n != 1) before decoding the next n_tokens of the batch.
// Each iteration of the inner loop shifts/divides KV-cache position ranges so that
// positions beyond the current group window are compressed by factor ga_n.
// Slots with ga_n == 1 (no Self-Extend) are untouched.
void server_context::extend_context(const int32_t n_tokens) {
    for (auto& slot : slots) {
        if (slot.ga_n != 1) {
            // context extension via Self-Extend
            // TODO: simplify and/or abstract this
            while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
                // ib: index of the current group block; bd: backward displacement applied
                // to positions past the window; dd: correction for the tail segment.
                const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;

                LOG_TEE("\n");
                LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
                LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

                // 1) shift everything after ga_i forward, 2) compress the window by ga_n,
                // 3) shift the tail back into place — the three ops must stay in this order.
                llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
                llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);

                slot.n_past_se -= bd;

                slot.ga_i += slot.ga_w / slot.ga_n;

                LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
            }

            // account for the tokens about to be decoded in this chunk
            slot.n_past_se += n_tokens;
        }
    }
}
|
|
|
|
void server_context::speculative_decoding_accept() {
|
|
for (auto& slot : slots) {
|
|
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
|
continue;
|
|
}
|
|
|
|
size_t n_draft = slot.drafted.size();
|
|
|
|
apply_server_biases(slot);
|
|
|
|
// the accepted tokens from the speculation
|
|
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
|
|
|
|
if (slot.has_mtp) {
|
|
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
|
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
|
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
if (!ids.empty()) {
|
|
const float* emb = llama_get_embeddings(ctx);
|
|
if (emb) {
|
|
slot.mtp_hidden_state.resize(ids.size() * n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb, ids.size() * n_embd * sizeof(float));
|
|
}
|
|
} else {
|
|
const float* emb0 = llama_get_embeddings_ith(ctx, 0);
|
|
if (emb0) {
|
|
slot.mtp_hidden_state.resize(n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb0, n_embd * sizeof(float));
|
|
}
|
|
}
|
|
|
|
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
|
|
|
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
|
mtp_accept_tokens(mtp_target, ids, n_past_base, slot.id);
|
|
}
|
|
|
|
slot.i_batch_dft.clear();
|
|
slot.drafted.clear();
|
|
|
|
slot.n_past += ids.size();
|
|
slot.n_decoded += ids.size();
|
|
const int64_t t_current = ggml_time_us();
|
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
|
|
|
// update how many tokens out of those tested were accepted
|
|
slot.n_draft_accepted += ids.size() - 1;
|
|
|
|
// inform the speculative decoding about the number of accepted tokens
|
|
common_speculative_accept(slot.spec, ids.size() - 1);
|
|
|
|
// rollback to the state before sampling the draft tokens
|
|
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
|
|
|
|
// add accepted tokens to the prompt
|
|
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
|
|
slot.sampled = ids.back(); // last accepted token
|
|
slot.n_past = slot.cache_tokens.n_tokens();
|
|
|
|
// for hybrid models: if any drafts were rejected, restore recurrent state
|
|
const bool any_rejected = (ids.size() - 1) < n_draft;
|
|
if (any_rejected && slot.spec_ckpt_valid) {
|
|
llama_state_seq_set_data(ctx, slot.spec_ckpt_data.data(), slot.spec_ckpt_data.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
llama_kv_cache_seq_rm(ctx, slot.id, slot.spec_ckpt_n_past, -1);
|
|
|
|
// restore sampler state (RNG, grammar, prev tokens)
|
|
if (slot.spec_ckpt_sampler) {
|
|
common_sampler_clone(slot.spec_ckpt_sampler, slot.ctx_sampling);
|
|
common_sampler_free(slot.spec_ckpt_sampler);
|
|
slot.spec_ckpt_sampler = nullptr;
|
|
}
|
|
|
|
if (!ids.empty()) {
|
|
const int n_accepted = (int)ids.size();
|
|
llama_batch re_batch = llama_batch_init(n_accepted, 0, 1);
|
|
for (int j = 0; j < n_accepted; j++) {
|
|
const bool is_last = (j == n_accepted - 1);
|
|
common_batch_add(re_batch, ids[j], slot.spec_ckpt_n_past + j, { slot.id }, is_last);
|
|
}
|
|
|
|
if (slot.has_mtp) {
|
|
llama_set_embeddings(ctx, true);
|
|
}
|
|
|
|
const int ret = llama_decode(ctx, re_batch);
|
|
if (ret != 0) {
|
|
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
|
|
}
|
|
|
|
if (slot.has_mtp) {
|
|
llama_set_embeddings(ctx, false);
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
const float * emb = llama_get_embeddings_ith(ctx, -1);
|
|
if (emb) {
|
|
slot.mtp_hidden_state.resize(n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
|
|
}
|
|
}
|
|
|
|
llama_batch_free(re_batch);
|
|
SLT_DBG(slot, "spec checkpoint restored: re-decoded %d accepted tokens (rejected %d)\n",
|
|
n_accepted, (int)(n_draft - (ids.size() - 1)));
|
|
}
|
|
|
|
slot.spec_ckpt_valid = false;
|
|
} else {
|
|
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
|
slot.spec_ckpt_valid = false;
|
|
// discard saved sampler on full acceptance
|
|
if (slot.spec_ckpt_sampler) {
|
|
common_sampler_free(slot.spec_ckpt_sampler);
|
|
slot.spec_ckpt_sampler = nullptr;
|
|
}
|
|
}
|
|
|
|
for (size_t i = 0; i < ids.size(); ++i) {
|
|
completion_token_output result;
|
|
|
|
result.tok = ids[i];
|
|
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
|
result.prob = 1.0f; // set later
|
|
|
|
if (slot.sparams.n_probs > 0) {
|
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
|
|
}
|
|
|
|
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
|
if (!process_token(result, slot)) {
|
|
// release slot because of stop condition
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
slot.n_past++;
|
|
send_final_response(slot);
|
|
release_slot_after_final_response(slot);
|
|
break;
|
|
}
|
|
} else {
|
|
buffer_and_check_string_ban(slot, result);
|
|
if (slot.task == nullptr) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
|
|
|
update_allowlist_state(slot);
|
|
}
|
|
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
|
|
LOG_VERBOSE("speculative decoding result", {
|
|
{"id_slot", slot.id},
|
|
{"accepted", (int)ids.size() - 1},
|
|
{"total", (int)slot.drafted.size()},
|
|
{"new_n_past", slot.n_past}
|
|
});
|
|
}
|
|
}
|
|
|
|
|
|
bool server_context::accept_special_token(const server_slot& slot, const llama_token token) {
|
|
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
|
|
}
|
|
|
|
// Finalize a slot after its final response has been sent: log timings, optionally
// persist a context checkpoint, release the slot back to the idle pool, and record
// metrics. The order matters: the checkpoint must be taken before release().
void server_context::release_slot_after_final_response(server_slot & slot) {
    slot.print_timings();
    if (params_base.do_checkpoint) {
        create_checkpoint(slot);
    }
    slot.release();
    slot.released = true;
    metrics.on_prediction(slot);
}
|
|
|
|
// Flush up to n buffered token results through process_token() (n <= 0 means all).
// Handles the stop condition by appending the last sampled token, sending the final
// response and releasing the slot. Consumed results are erased from the vector;
// positional bans for confirmed positions are cleaned up as we go.
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
    int count = 0;
    bool released = false;

    // absolute position of the first buffered result
    int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;

    // shared finalize sequence for the stop condition (was duplicated twice)
    auto finalize = [&]() {
        slot.cache_tokens.push_back(slot.sampled);
        slot.n_past++;
        send_final_response(slot);
        release_slot_after_final_response(slot);
    };

    for (auto& it : results) {
        bool has_next = process_token(it, slot);

        // Clean up positional bans for the token we just confirmed/sent
        slot.positional_bans.erase(start_pos + count);

        count++;
        if (!has_next) {
            if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
                // limit-stop is deferred until after the loop so remaining results drain
                continue;
            }
            finalize();
            released = true;
            break;
        }
        if (n > 0 && count >= n) {
            break;
        }
    }

    // deferred finalize for the pure token-limit stop
    if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
        finalize();
    }

    if (count > 0) {
        // NOTE(review): this takes the tok of the LAST element in results, even when
        // count < results.size() (early break) — confirm this is intended vs results[count-1].
        slot.sampled = results[results.size()-1].tok;
        results.erase(results.begin(), results.begin() + count);
    }

}
|
|
|
|
// Scan the slot's buffered (not yet sent) token text for banned content:
// 1) case-insensitive plain phrases, 2) case-sensitive regexes, 3) case-insensitive
// regexes. Returns the absolute position of the token containing the earliest
// match, or -1 if nothing is banned.
inline int32_t check_ban_phrase(server_slot& slot) {
    if (slot.token_buffer.empty()) return -1;

    // concatenate buffered token text, remembering where each token starts
    std::string string_buffer;
    std::vector<size_t> token_offsets;
    for (const auto& out : slot.token_buffer) {
        token_offsets.push_back(string_buffer.size());
        string_buffer += out.text_to_send;
    }

    size_t best_start = std::string::npos;
    bool   found      = false;

    // record a candidate match start, keeping the earliest one
    auto consider = [&](size_t start) {
        if (start != std::string::npos && start < best_start) {
            best_start = start;
            found = true;
        }
    };

    std::string string_buffer_lower = string_lower(string_buffer);

    // 1. Check strings (against the lowercased buffer)
    for (const auto& phrase : slot.ban_phrases) {
        consider(string_buffer_lower.find(phrase));
    }

    // 2. Check regex (invalid patterns are skipped)
    for (const auto& pattern : slot.ban_regex) {
        try {
            std::regex re(pattern);
            std::smatch match;
            if (std::regex_search(string_buffer, match, re)) {
                consider((size_t)match.position());
            }
        } catch (...) { continue; }
    }

    // 3. Check regex case insensitive
    for (const auto& pattern : slot.ban_regex_ci) {
        try {
            std::regex re(pattern, std::regex_constants::icase);
            std::smatch match;
            if (std::regex_search(string_buffer, match, re)) {
                consider((size_t)match.position());
            }
        } catch (...) { continue; }
    }

    if (!found) {
        return -1;
    }

    // map the earliest match offset back to the buffered token that contains it
    for (size_t i = 0; i < token_offsets.size(); ++i) {
        const size_t len = (i == token_offsets.size() - 1)
            ? string_buffer.size() - token_offsets[i]
            : token_offsets[i + 1] - token_offsets[i];

        if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
            // absolute position of that token in the sequence
            return slot.n_past - (int32_t)slot.token_buffer.size() + 1 + (int32_t)i;
        }
    }

    return -1;
}
|
|
|
|
// Rewind the slot's generation back to just before ban_pos (an absolute token
// position where a banned phrase started): record positional token bans so the
// same tokens are penalized on regeneration, truncate the token cache, KV cache,
// and the pending output buffer, and adjust the decoded-token counter.
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
    slot.rewind_count++;

    // how many buffered-but-unsent results precede ban_pos and can be kept
    int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
    int32_t n_keep_buffer = ban_pos - buffer_start_pos;
    if (n_keep_buffer < 0) n_keep_buffer = 0;

    if (slot.banned_n != 0) {
        // ban up to banned_n tokens starting at the offending one (banned_n < 0 => all)
        int32_t n = 0;
        for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
            llama_token banned_tok = result->tok;

            if (n == 0) {
                LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
                    ban_pos, banned_tok, result->text_to_send.c_str());
            }

            // remember the ban at its absolute position so regeneration avoids it
            slot.positional_bans[ban_pos].insert(banned_tok);
            n++;
            if (slot.banned_n > 0 && n == slot.banned_n) {
                break;
            }
        }
    }

    int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;

    // keep cached tokens strictly before ban_pos
    size_t n_keep_cache = 0;
    if (ban_pos > 0) {
        n_keep_cache = (size_t)(ban_pos - 1);
    }

    if (n_keep_cache > slot.cache_tokens.size()) {
        n_keep_cache = slot.cache_tokens.size();
    }

    // the token at the new tail becomes the "last sampled" token to re-feed
    if (n_keep_cache < slot.cache_tokens.size()) {
        slot.sampled = slot.cache_tokens[n_keep_cache];
    } else {
        slot.sampled = 0;
    }

    // Truncate cache
    slot.cache_tokens.keep_first(n_keep_cache);
    slot.n_past = slot.cache_tokens.n_tokens();

    // Remove from KV cache (everything at or after the new n_past)
    llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);

    // Truncate buffer
    slot.token_buffer.resize(n_keep_buffer);

    // Adjust decoded count so rewinds don't consume the prediction budget
    if (slot.saturate_predict) {
        slot.n_decoded -= n_rewind_total;
        if (slot.n_decoded < 0) slot.n_decoded = 0;
    }
}
|
|
|
|
// Buffer one sampled token and run the banned-phrase check over the buffered text.
// Depending on the outcome: rewind the context (ban hit, rewinds still allowed),
// flush tokens to the client (buffer full or generation stopped), or just keep
// buffering until enough text is available to validate strings.
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
    slot.token_buffer.push_back(result);

    bool next_token = has_next_token(result, slot);
    // If buffer full or generation stopped, we might send tokens
    bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;

    int32_t ban_pos = -1;
    bool sent_results = false;

    // Always reset logit bias to base before checking bans
    slot.ctx_sampling->params.logit_bias = slot.logit_bias;

    if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
        ban_pos = check_ban_phrase(slot);
    }

    bool allow_rewind = true;

    if (ban_pos >= 0) {
        if (slot.rewind_count_max == -1) {
            // Automatic / Heuristic logic
            // Account for strings + regex + regex_ci
            size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();

            // Heuristic: Allow if under 20 OR under 2 * total_bans
            // Conversely: Stop if >= 20 AND > 2 * total_bans
            if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
                allow_rewind = false;
            }
        }
        else if (slot.rewind_count_max > 0) {
            // Strict limit logic
            if (slot.rewind_count >= slot.rewind_count_max) {
                allow_rewind = false;
            }
        }
        // If slot.rewind_count_max == 0, allow_rewind remains true (Infinite)
    }

    if (ban_pos >= 0 && allow_rewind) {
        // ban hit: roll back and regenerate with positional bans applied
        rewind_context(slot, ban_pos);
        slot.rewind_status = true;
    }
    else if (buffer_full || !next_token) {
        slot.rewind_status = false;
        slot.rewind_count = 0;

        if (!next_token) {
            // send all remaining tokens
            send_token_results(slot.token_buffer, slot);
        }
        else {
            // send 1 token from the front (FIFO)
            send_token_results(slot.token_buffer, slot, 1);
        }
    }
    else {
        // buffer the result, wait for more tokens to validate string
        slot.sampled = result.tok;
    }
}
|
|
|
|
// Advance the allow-keyword cursor (slot.allow_idx) for every ordered keyword in
// slot.allow_kws that now appears in the generated text. Does nothing until
// allow_kw_delay tokens have been decoded or once all keywords are matched.
void server_context::update_allowlist_state(server_slot& slot) {
    const auto& keywords = slot.allow_kws;
    auto&       cursor   = slot.allow_idx;

    if (slot.allow_kw_delay > slot.n_decoded) {
        return; // not enough tokens decoded yet
    }
    if (cursor >= keywords.size()) {
        return; // every keyword already matched
    }

    // Only rescan the tail of generated_text that could contain a new
    // occurrence of the current keyword (text before last_gentxt_size was
    // already checked, minus a keyword-length overlap for boundary matches).
    auto kw  = keywords[cursor];
    auto pos = slot.generated_text.find(kw, std::max(0, slot.last_gentxt_size - (int32_t)kw.length() + 1));
    while (pos != std::string::npos) {
        if (++cursor >= keywords.size()) {
            break;
        }
        // look for the next expected keyword strictly after this match
        kw  = keywords[cursor];
        pos = slot.generated_text.find(kw, pos + 1);
    }
}
|
|
|
|
void server_context::process_batch_tokens(int32_t & n_batch) {
|
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
|
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
|
extend_context(n_tokens);
|
|
|
|
llama_batch batch_view = {
|
|
n_tokens,
|
|
batch.token + i,
|
|
nullptr,
|
|
batch.pos + i,
|
|
batch.n_seq_id + i,
|
|
batch.seq_id + i,
|
|
batch.logits + i,
|
|
0, 0, 0, // unused
|
|
};
|
|
|
|
const int ret = llama_decode(ctx, batch_view);
|
|
if (ret != 0) {
|
|
if (n_batch == 1 || ret < 0) {
|
|
int user_cancel = -3;
|
|
if (ret == user_cancel) {
|
|
LLAMA_LOG_INFO("Decode process is cancelled by user.\n");
|
|
}
|
|
else {
|
|
// if you get here, it means the KV cache is full - try increasing it via the context size
|
|
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
|
|
{"i", i},
|
|
{"n_batch", ret},
|
|
{"ret", ret},
|
|
});
|
|
}
|
|
|
|
for (auto& slot : slots) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.release();
|
|
if (ret != user_cancel) {
|
|
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
|
|
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
|
}
|
|
}
|
|
break; // break loop of n_batch
|
|
}
|
|
// retry with half the batch size to try to find a free slot in the KV cache
|
|
n_batch /= 2;
|
|
i -= n_batch;
|
|
|
|
LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
|
|
{"i", i},
|
|
{"n_batch", n_batch},
|
|
{"ret", ret},
|
|
});
|
|
|
|
continue; // continue loop of n_batch
|
|
}
|
|
|
|
bool mtp_warmup_needed = false;
|
|
std::vector<float> batch_mtp_hidden_state;
|
|
if (params_base.has_mtp) {
|
|
for (auto& slot : slots) {
|
|
if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
|
|
(slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
|
|
bool has_tokens_for_slot = (batch_view.n_tokens > 0 && batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id);
|
|
if (has_tokens_for_slot) {
|
|
mtp_warmup_needed = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (mtp_warmup_needed) {
|
|
const float* emb = llama_get_embeddings(ctx);
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
const int n_toks = batch_view.n_tokens;
|
|
if (emb) {
|
|
batch_mtp_hidden_state.resize(n_toks * n_embd);
|
|
memcpy(batch_mtp_hidden_state.data(), emb, n_toks * n_embd * sizeof(float));
|
|
}
|
|
}
|
|
}
|
|
|
|
for (auto& slot : slots) {
|
|
bool is_active_slot = (slot.state == SLOT_STATE_PROCESSING);
|
|
|
|
if (!is_active_slot || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
|
// save checkpoint during prompt processing
|
|
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
|
if (slot.do_checkpoint) {
|
|
create_checkpoint(slot);
|
|
} else {
|
|
create_checkpoint_at_interval(slot, params_base);
|
|
}
|
|
}
|
|
continue; // continue loop of slots
|
|
}
|
|
|
|
// prompt evaluated for embedding
|
|
if (slot.embedding) {
|
|
send_embedding(slot, batch_view);
|
|
slot.release();
|
|
slot.i_batch = -1;
|
|
continue; // continue loop of slots
|
|
}
|
|
|
|
if (slot.n_decoded == 0 && slot.can_speculate()) {
|
|
common_speculative_begin(slot.spec, slot.cache_tokens.get_text_tokens());
|
|
}
|
|
|
|
if (slot.i_batch_dft.size() > 0) {
|
|
continue; // sample using speculative decoding
|
|
}
|
|
|
|
// RESTORE AND APPLY POSITIONAL BANS
|
|
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
|
auto ban_it = slot.positional_bans.find(slot.n_past);
|
|
if (ban_it != slot.positional_bans.end()) {
|
|
for (llama_token tok : ban_it->second) {
|
|
slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
|
|
}
|
|
}
|
|
|
|
completion_token_output result;
|
|
const int tok_idx = slot.i_batch - i;
|
|
|
|
if (params_base.has_mtp && slot.n_decoded == 0) {
|
|
const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
|
|
if (emb_i) {
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
slot.mtp_hidden_state.resize(n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb_i, n_embd * sizeof(float));
|
|
}
|
|
}
|
|
|
|
apply_server_biases(slot);
|
|
|
|
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
|
|
|
|
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
|
|
|
slot.n_decoded += 1;
|
|
const int64_t t_current = ggml_time_us();
|
|
|
|
if (slot.n_decoded == 1) {
|
|
slot.t_start_generation = ggml_time_us();
|
|
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
|
metrics.on_prompt_eval(slot);
|
|
// create checkpoint after prompt processing ends
|
|
if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) {
|
|
create_checkpoint(slot);
|
|
}
|
|
}
|
|
|
|
// create checkpoint during generation
|
|
if (slot.n_decoded > 1) {
|
|
create_checkpoint_at_interval(slot, params_base);
|
|
}
|
|
|
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
|
|
|
result.tok = id;
|
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
|
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
|
if (slot.sparams.n_probs > 0) {
|
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
|
|
}
|
|
|
|
// no ban string for recurrent/hybrid model
|
|
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
|
slot.token_buffer = { result };
|
|
send_token_results(slot.token_buffer, slot);
|
|
} else {
|
|
// buffer the result and check string ban.
|
|
// if ban, we need to go back, apply logit bias and regenerate
|
|
buffer_and_check_string_ban(slot, result);
|
|
}
|
|
|
|
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
|
|
|
update_allowlist_state(slot);
|
|
|
|
slot.i_batch = -1;
|
|
}
|
|
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
|
|
llama_context * mtp_ctx = nullptr;
|
|
for (auto & slot : slots) {
|
|
if (slot.spec && slot.has_mtp) {
|
|
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
|
|
if (mc) { mtp_ctx = mc; break; }
|
|
}
|
|
}
|
|
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
|
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
|
mtp_update_kv_cache(mtp_target, batch_view, true);
|
|
}
|
|
|
|
// speculative decoding - main model sample and accept
|
|
speculative_decoding_accept();
|
|
}
|
|
}
|
|
|
|
// Main per-iteration scheduler: release finished slots, build the next batch
// (sampled tokens first, then pending prompt chunks), snapshot recurrent state +
// sampler for hybrid-model speculative decoding, and decode the batch.
void server_context::update_slots() {
    if (system_need_update) {
        system_prompt_update();
    }
    // release slots
    release_slots();

    // check if all slots are idle
    if (slots_idle()) {
        return;
    }

    {
        // keep the task loop alive so the next iteration is scheduled
        LOG_VERBOSE("posting NEXT_RESPONSE", {});
        server_task task;
        task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
        task.id_target = -1;

        queue_tasks.post(std::move(task));
    }

    // apply context-shift if needed
    // TODO: simplify and improve
    context_shift();

    // start populating the batch for this iteration
    common_batch_clear(batch);

    // first, add sampled tokens from any ongoing sequences
    add_sampled_tokens(); // Prepare batch for inference

    // process in chunks of params.n_batch
    int32_t n_batch = llama_n_batch(ctx);
    int32_t n_ubatch = llama_n_ubatch(ctx);

    // track if this is an embedding or non-embedding batch
    // if we've added sampled tokens above, we are in non-embedding mode
    // -1: none, 0: non-embedding, 1: embedding
    int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

    // next, batch any pending prompts without exceeding n_batch
    batch_pending_prompt(n_ubatch, n_batch, batch_type); // Prepare batch for prompt process

    if (batch.n_tokens == 0) {
        LOG_VERBOSE("no tokens to decode", {});
        return;
    }

    LOG_VERBOSE("decoding batch", {
        {"n_tokens", batch.n_tokens},
    });

    // make sure we're in the right embedding mode
    llama_set_embeddings(ctx, batch_type == 1);

    // Hybrid models (recurrent + attention): before decoding a speculative batch,
    // checkpoint the partial (recurrent) sequence state AND a clone of the sampler,
    // so speculative_decoding_accept() can roll both back if draft tokens are
    // rejected. Without the sampler clone, RNG/grammar state would reflect the
    // rejected drafts and diverge.
    if (llama_model_is_hybrid(model)) {
        for (auto & slot : slots) {
            if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
                continue;
            }
            // position before the drafted tokens + the token that spawned them
            slot.spec_ckpt_n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1);
            const size_t ckpt_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
            slot.spec_ckpt_data.resize(ckpt_size);
            const size_t written = llama_state_seq_get_data(ctx, slot.spec_ckpt_data.data(), ckpt_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
            slot.spec_ckpt_valid = (written > 0);
            if (slot.spec_ckpt_valid) {
                // save sampler state so we can restore RNG/grammar on rejection
                if (slot.spec_ckpt_sampler) {
                    common_sampler_free(slot.spec_ckpt_sampler);
                }
                slot.spec_ckpt_sampler = common_sampler_init(model, slot.sparams);
                common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt_sampler);
                SLT_DBG(slot, "spec checkpoint saved: %zu bytes, n_past_pre_spec=%d\n", written, slot.spec_ckpt_n_past);
            } else {
                SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
            }
        }
    }

    // process the created batch of tokens
    process_batch_tokens(n_batch); // Decode with batch

    LOG_VERBOSE("run slots completed", {});
}
|
|
|
|
// Static metadata about the loaded model (vocab type/size, training context,
// embedding width, parameter count, file size), reported to clients.
json server_context::model_meta() const {
    json meta;
    meta["vocab_type"]  = llama_vocab_type(llama_model_get_vocab(model));
    meta["n_vocab"]     = llama_n_vocab(model);
    meta["n_ctx_train"] = llama_n_ctx_train(model);
    meta["n_embd"]      = llama_model_n_embd(model);
    meta["n_params"]    = llama_model_n_params(model);
    meta["size"]        = llama_model_size(model);
    return meta;
}
|