Files
ik_llama.cpp/examples/server/server-context.cpp
SamuelOliveirads d93dfb5e6b fix: save/restore sampler state during speculative checkpoint
When speculative decoding rejects draft tokens and restores the
recurrent state checkpoint, the sampler (RNG, grammar, prev tokens)
must also be restored to maintain consistency. Without this, the
sampler state reflects the rejected draft tokens, leading to
potential divergence.

Uses common_sampler_clone() to snapshot the sampler before the
speculative batch decode, and restores it on rejection.
2026-04-16 22:36:37 -03:00

4284 lines
172 KiB
C++

#include "server-context.h"
#include "server-common.h"
#include "server-task.h"
#include "server-queue.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"
#include "mtmd-helper.h"
#include <fstream>
#include <iostream>
#include <regex>
// Forward a line of text to the log file, but only when minimal logging
// (`--minilog`) is active.
static void log_text(const gpt_params & params_base, const std::string & text) {
    if (!params_base.minilog) {
        return;
    }
    LOG_TEE("%s\n", text.c_str());
}
// Tear down the server context. Contexts are released before their models,
// then per-slot resources (samplers, draft contexts, speculative state,
// per-slot batches), and finally the shared batch.
server_context::~server_context() {
    // main context/model pair
    if (ctx) {
        llama_free(ctx);
        ctx = nullptr;
    }
    if (model) {
        llama_free_model(model);
        model = nullptr;
    }
    // multimodal context (freed unconditionally, unlike the guarded frees above)
    mtmd_free(mctx);
    // draft context/model pair used for speculative decoding
    if (ctx_draft) {
        llama_free(ctx_draft);
        ctx_draft = nullptr;
    }
    if (model_draft) {
        llama_free_model(model_draft);
        model_draft = nullptr;
    }
    // per-slot resources
    for (server_slot & slot : slots) {
        if (slot.ctx_sampling != nullptr) {
            common_sampler_free(slot.ctx_sampling);
        }
        if (slot.spec_ckpt_sampler != nullptr) {
            common_sampler_free(slot.spec_ckpt_sampler);
        }
        if (slot.ctx_dft) {
            llama_free(slot.ctx_dft);
        }
        common_speculative_free(slot.spec);
        llama_batch_free(slot.batch_spec);
    }
    llama_batch_free(batch);
}
// Load the main model, and optionally a multimodal projector (mmproj) and/or a
// draft model for speculative decoding, according to `params_`.
// Returns false on any load failure. Multimodal and speculative decoding are
// mutually exclusive: a draft model together with mmproj is a hard error.
bool server_context::load_model(const gpt_params& params_) {
params_base = params_;
llama_init_result llama_init = llama_init_from_gpt_params(params_base);
model = llama_init.model;
ctx = llama_init.context;
lora_adapters = llama_init.lora_adapters;
if (model == nullptr) {
LOG_ERROR("unable to load model", { {"model", params_base.model} });
return false;
}
n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_should_add_bos_token(model);
has_eos_token = llama_add_eos_token(model) != 1;
// a draft model is requested either via an explicit model path or via extra
// draft-only CLI params
bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
std::string& mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
// --- multimodal projector setup ---
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.n_threads;
mparams.flash_attn_type = params_base.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
return false;
}
LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());
//if (params.n_cache_reuse) {
// params_base.n_cache_reuse = 0;
// SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
//}
// speculative decoding cannot be combined with multimodal
if (has_draft_model) {
LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
return false;
}
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
}
}
// Load draft model for speculative decoding if specified
if (has_draft_model) {
LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
// draft params start from defaults, then inherit selected base params,
// then may be overridden by the extra `--speculative ...` CLI string
gpt_params params_dft;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.rpc_servers = params_base.rpc_servers;
params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
params_dft.cache_type_v = params_base.speculative.cache_type_v.empty() ? params_base.cache_type_v : params_base.speculative.cache_type_v;
params_dft.flash_attn = params_base.flash_attn;
if (!params_base.speculative.params.empty()) {
// re-parse the draft-specific CLI fragment as if it were a command line
auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params);
if (!gpt_params_parse(argc, argv, params_dft)) {
gpt_params_print_usage(argc, argv, params_dft);
free_command_line(argc, argv);
return false;
};
free_command_line(argc, argv);
}
LOG_INFO("", { {"model", params_dft.model} });
// draft context size: explicit > speculative.n_ctx > per-slot share of base n_ctx
if (params_dft.n_ctx == 0) {
params_dft.n_ctx = params_base.speculative.n_ctx;
}
params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
params_dft.n_parallel = 1;
params_dft.n_batch = params_dft.n_ctx;
params_base.speculative.mparams_dft.path = params_dft.model; // remember the draft model path
llama_model_params mparams_dft = common_model_params_to_llama(params_dft);
llama_model * model_dft = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
if (model_dft == nullptr) {
LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
return false;
}
// stash the draft model/context params so each slot can create its own
// draft context later (see init())
cparams_dft = common_context_params_to_llama(params_dft);
params_base.speculative.model_dft = model_dft;
params_base.speculative.cparams_dft = cparams_dft;
}
else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
// MTP requested but the model carries no NextN layers - disable it early
LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
params_base.has_mtp = false;
}
return true;
}
// Initialize server state after the model is loaded: create the slots (each
// with its own share of the context window, group-attention / MTP /
// speculative-decoding setup), the shared batch, the metrics, the optional
// prompt cache, and the chat-template parameters.
void server_context::init() {
// each parallel slot gets an equal share of the context
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params_base.n_predict;
slot.mctx = mctx;
slot.cache_tokens.has_mtmd = mctx != nullptr;
slot.params.think_tokens = params_base.think_tokens;
if (params_base.think_tokens.exclude) {
SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n", params_base.think_tokens.begin.c_str(), params_base.think_tokens.end.c_str() );
}
else {
SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
}
LOG_INFO("new slot", {
{"id_slot", slot.id},
{"n_ctx_slot", slot.n_ctx}
});
// group-attention (self-extend) parameters
const int ga_n = params_base.grp_attn_n;
const int ga_w = params_base.grp_attn_w;
if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
LOG_INFO("slot self-extend", {
{"id_slot", slot.id},
{"ga_n", ga_n},
{"ga_w", ga_w}
});
}
slot.ga_i = 0;
slot.ga_n = ga_n;
slot.ga_w = ga_w;
slot.sparams = params_base.sparams;
// MTP (multi-token prediction) speculative setup - only usable when the
// model actually has NextN layers
if (params_base.has_mtp) {
if (llama_model_n_nextn_layer(model) > 0) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;
params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
params_base.speculative.cparams_dft.mtp = true;
params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
params_base.speculative.cparams_dft.embeddings = true;
slot.has_mtp = true;
slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
slot.params.speculative.n_min = 0;
slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;
// +1: room for the sampled token in addition to the draft tokens
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
llama_set_embeddings(ctx, true);
}
else {
SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
slot.has_mtp = false;
}
}
// speculative decoding compatibility check is skipped on dry runs
bool can_spec = true;
if (!params_base.dry_run) {
can_spec = common_speculative_is_compat(ctx);
}
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// try speculative decoding
if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return;
}
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
} else {
if (slot.has_mtp) {
// MTP was explicitly requested; failing to build its context is fatal
SRV_ERR("%s", "failed to initialize MTP speculative context, aborting\n");
GGML_ABORT("MTP context creation failed");
} else {
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
}
}
}
slot.reset();
slots.push_back(std::move(slot));
}
default_generation_settings_for_props = get_formated_generation(slots.front());
default_generation_settings_for_props["seed"] = -1;
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
// optional RAM-bounded prompt cache (negative limit means "no limit")
if (params_base.cache_ram_mib != 0) {
if (params_base.cache_ram_mib < 0) {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
}
else {
LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
}
LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
// only apply ram size limit. No token limit for now.
prompt_cache = std::make_unique<server_prompt_cache>(ctx, params_base.cache_ram_mib, 0);
}
else {
LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
}
// populate chat template params
{
common_chat_templates_ptr chat_templates;
try {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
}
catch (const std::exception & e) {
SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
SRV_ERR("%s: please consider enabling jinja via --jinja, or use a custom chat template via --chat-template\n", __func__);
SRV_ERR("%s: for example: --chat-template chatml\n", __func__);
}
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
chat_params = {
/* use_jinja */ params_base.use_jinja,
/* use_peg */ params_base.use_peg,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* tmpls */ std::move(chat_templates),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio(mctx) : false,
/* enable_thinking */ enable_thinking,
// /* media_path */ params_base.media_path,
};
}
}
// Snapshot this slot's sequence state (KV cache etc.) into the prompt cache.
// Precondition: the cached-prompt entry has no state data attached yet.
// On allocation failure the save is silently skipped (best-effort cache).
void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
assert(server_cached_prompt.data.size() == 0);
const size_t cur_size = llama_state_seq_get_size(ctx, id, 0);
LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
(int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
// alloc may evict older entries; returns nullptr if the state doesn't fit
auto* cur = prompt_cache.alloc(server_cached_prompt, cur_size);
if (cur == nullptr) {
return;
}
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
}
// Restore a previously cached prompt state matching `tokens` into this slot's
// sequence; failures are logged but otherwise non-fatal (best-effort cache).
void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
    const bool loaded = prompt_cache.load(server_cached_prompt, tokens, ctx, id);
    if (!loaded) {
        LLAMA_LOG_INFO("failed to load prompt from cache\n");
    }
}
// Return the slot to a pristine state between tasks: clears all generation
// progress, stop flags, speculative-decoding state, banning/allow rules and
// per-response identifiers.
void server_slot::reset() {
    // prompt / generation progress
    n_prompt_tokens = 0;
    last_gentxt_size = 0;
    generated_text.clear();
    n_past = 0;
    n_past_prompt = 0;
    n_sent_text = 0;
    n_sent_token_probs = 0;
    generated_token_probs.clear();
    // stop conditions
    truncated = false;
    stopped_eos = false;
    stopped_word = false;
    stopped_limit = false;
    stopping_word.clear();
    // speculative decoding state; the checkpoint sampler snapshot (taken
    // before a speculative batch decode) is released here
    drafted.clear();
    i_batch_dft.clear();
    spec_ckpt_valid = false;
    if (spec_ckpt_sampler) {
        common_sampler_free(spec_ckpt_sampler);
        spec_ckpt_sampler = nullptr;
    }
    // misc per-task state
    infill = false;
    ga_i = 0;
    n_past_se = 0;
    chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    logit_bias.clear();
    token_buffer.clear();
    rewind_count = 0;
    n_buffer = 0;
    rewind_status = false;
    checkpoint_pos = 0;
    image_just_processed = false;
    do_checkpoint = false;
    // token banning / allow rules
    positional_bans.clear();
    ban_phrases.clear();
    ban_regex.clear();
    ban_regex_ci.clear();
    allow_ruless.clear();
    allow_pieces.clear();
    allow_kws.clear();
    allow_kw_delay = 0;
    allow_idx = 0;
    // Reset speculative decoding stats
    n_draft_total = 0;
    n_draft_accepted = 0;
    // chat / response bookkeeping
    chat_msg = {};
    json_schema = json();
    generated_tool_call_ids.clear();
    anthropic_thinking_block_started = false;
    anthropic_text_block_started = false;
    oai_resp_thinking_block_started = false;
    oai_resp_text_block_started = false;
    oai_resp_id.clear();
    oai_resp_reasoning_id.clear();
    oai_resp_message_id.clear();
    oai_resp_fc_id.clear();
    task.reset();
}
// Embeddings are needed either for embedding tasks or for MTP drafting.
bool server_slot::need_embd() const {
    if (embedding) {
        return true;
    }
    return has_mtp;
}
// Check whether this slot may still generate tokens. The per-request limit
// takes precedence over the global one; -1 means unlimited. Side effect:
// updates `n_remaining` (stays -1 when both limits are unlimited).
bool server_slot::has_budget(gpt_params& global_params) {
    const bool slot_limited   = params.n_predict != -1;
    const bool global_limited = global_params.n_predict != -1;
    if (!slot_limited && !global_limited) {
        return true; // limitless
    }
    n_remaining = -1;
    if (slot_limited) {
        n_remaining = params.n_predict - n_decoded;
    } else if (global_limited) {
        n_remaining = global_params.n_predict - n_decoded;
    }
    return n_remaining > 0;
}
// A slot is available when it is idle and has no pending command.
bool server_slot::available() const {
    const bool is_idle = state == SLOT_STATE_IDLE;
    return is_idle && command == SLOT_COMMAND_NONE;
}
// A slot counts as processing when it is actively generating, or when it is
// idle but queued to load a prompt.
bool server_slot::is_processing() const {
    if (state == SLOT_STATE_PROCESSING) {
        return true;
    }
    return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT;
}
// Append a generated token (with its probabilities) to the slot's output,
// unless the slot has already been asked to release.
void server_slot::add_token_string(const completion_token_output& token) {
    if (command != SLOT_COMMAND_RELEASE) {
        generated_token_probs.push_back(token);
    }
}
// Speculative decoding is possible with either a draft context or MTP layers.
bool server_slot::can_speculate() const {
    return spec != nullptr || has_mtp;
}
// Compute how many draft tokens may be generated for this slot right now.
// Returns 0 when speculative decoding is unavailable or not worthwhile
// (below the configured minimum).
int server_slot::get_n_draft_max() const {
    if (!can_speculate()) {
        return 0;
    }
    // start from the configured maximum draft length
    int result = params.speculative.n_max;
    // note: slot.prompt is not yet expanded with the `id` token sampled above
    // also, need to leave space for 1 extra token to allow context shifts
    result = std::min(result, n_ctx - n_past - 2);
    // respect the remaining generation budget, when one is set
    if (n_remaining > 0) {
        result = std::min(result, n_remaining - 1);
    }
    SLT_DBG(*this, "max possible draft: %d\n", result);
    if (result < params.speculative.n_min) {
        SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", result, params.speculative.n_min);
        result = 0;
    }
    return result;
}
// Release a processing slot: record the generation time, flag the release
// command, flip back to idle and drop the current task.
void server_slot::release() {
    if (state != SLOT_STATE_PROCESSING) {
        return;
    }
    t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
    command = SLOT_COMMAND_RELEASE;
    state = SLOT_STATE_IDLE;
    task.reset();
    llama_decode_reset();
}
// Build a JSON report of prompt-processing and generation timings.
// Note: per-token / per-second figures are computed directly and are
// undefined while the corresponding counters are still zero.
json server_slot::get_formated_timings() const {
    json timings;
    timings["prompt_n"]             = n_prompt_tokens_processed;
    timings["prompt_ms"]            = t_prompt_processing;
    timings["prompt_per_token_ms"]  = t_prompt_processing / n_prompt_tokens_processed;
    timings["prompt_per_second"]    = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
    timings["predicted_n"]            = n_decoded;
    timings["predicted_ms"]           = t_token_generation;
    timings["predicted_per_token_ms"] = t_token_generation / n_decoded;
    timings["predicted_per_second"]   = 1e3 / t_token_generation * n_decoded;
    timings["n_ctx"]  = n_ctx;
    timings["n_past"] = n_past;
    return timings;
}
// Collect prompt and generation timings into a result_timings struct.
// Speculative-decoding stats are included only when drafting occurred.
result_timings server_slot::get_timings() const {
    result_timings t;
    t.n_ctx  = n_ctx;
    t.n_past = n_past;
    // prompt processing
    t.prompt_n            = n_prompt_tokens_processed;
    t.prompt_ms           = t_prompt_processing;
    t.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
    t.prompt_per_second   = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
    // token generation
    t.predicted_n            = n_decoded;
    t.predicted_ms           = t_token_generation;
    t.predicted_per_token_ms = t_token_generation / n_decoded;
    t.predicted_per_second   = 1e3 / t_token_generation * n_decoded;
    // Add speculative metrics
    if (n_draft_total > 0) {
        t.draft_n          = n_draft_total;
        t.draft_n_accepted = n_draft_accepted;
    }
    return t;
}
// Re-parse the accumulated generated text into a chat message, update the
// slot's cached message and emit the diffs against the previous parse.
// The parse is treated as partial unless generation stopped on EOS.
const common_chat_msg& server_slot::update_chat_msg(std::vector<common_chat_msg_diff>& diffs) {
    const auto msg_prev = chat_msg;
    auto msg_new = common_chat_parse(
        generated_text,
        /* is_partial= */ stop != STOP_TYPE_EOS,
        params.oaicompat_chat_syntax);
    // an empty parse keeps the previous message and produces no diffs
    if (!msg_new.empty()) {
        msg_new.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
        chat_msg = msg_new;
        diffs = common_chat_msg_diff::compute_diffs(msg_prev, msg_new.empty() ? msg_prev : msg_new);
    }
    return chat_msg;
}
// Scan `text` for any configured stop word. With `is_full_stop` set, only the
// freshly appended tail region is searched and a hit halts generation;
// otherwise a partial match at the end of the stream is looked for.
// Returns the earliest match position, or npos when nothing matched.
size_t server_slot::find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop) {
    size_t stop_pos = std::string::npos;
    for (const std::string& word : params.antiprompt) {
        size_t pos = std::string::npos;
        if (is_full_stop) {
            // only the region touched by the last token can hold a new match
            const size_t lookback = word.size() + last_token_size;
            const size_t from_pos = text.size() > lookback ? text.size() - lookback : 0;
            pos = text.find(word, from_pos);
        } else {
            pos = string_find_partial_stop(text, word);
        }
        // keep only matches that improve on the earliest one found so far
        if (pos == std::string::npos || (stop_pos != std::string::npos && pos >= stop_pos)) {
            continue;
        }
        if (is_full_stop) {
            stopped_word = true;
            stopping_word = word;
            has_next_token = false;
        }
        stop_pos = pos;
    }
    return stop_pos;
}
// Log a human-readable timing summary for this slot: prompt evaluation,
// token generation, and (when drafting happened) the draft acceptance rate.
// Per-token figures divide by counters that may be zero; matches the existing
// reporting behavior (inf/nan is printed in that case).
void server_slot::print_timings() const {
    const double t_prompt        = t_prompt_processing / n_prompt_tokens_processed;
    const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
    const double t_gen        = t_token_generation / n_decoded;
    const double n_gen_second = 1e3 / t_token_generation * n_decoded;
    SLT_INF(*this,
    "\n"
    "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
    " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
    " total time = %10.2f ms / %5d tokens\n",
    t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
    t_token_generation, n_decoded, t_gen, n_gen_second,
    t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
    if (n_draft_total > 0) {
        const float draft_ratio = (float)n_draft_accepted / n_draft_total;
        SLT_CNT(*this,
        "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
        draft_ratio, n_draft_accepted, n_draft_total
        );
    }
    // NOTE(review): const_cast works around common_speculative_print_stats
    // taking a non-const pointer; consider fixing that signature instead.
    common_speculative_print_stats(spec, n_gen_second, n_decoded, n_past,
    const_cast<common_params_speculative *>(&params.speculative));
}
// Record the server start timestamp, used for uptime/metrics reporting.
void server_metrics::init() {
    t_start = ggml_time_us();
}
// Accumulate a slot's prompt-processing stats into both the windowed bucket
// counters and the lifetime totals.
void server_metrics::on_prompt_eval(const server_slot& slot) {
    const auto n_tokens = slot.n_prompt_tokens_processed;
    const auto t_ms     = slot.t_prompt_processing;
    n_prompt_tokens_processed       += n_tokens;
    n_prompt_tokens_processed_total += n_tokens;
    t_prompt_processing       += t_ms;
    t_prompt_processing_total += t_ms;
}
// Accumulate a slot's token-generation stats into both the windowed bucket
// counters and the lifetime totals.
void server_metrics::on_prediction(const server_slot& slot) {
    const auto n_tokens = slot.n_decoded;
    const auto t_ms     = slot.t_token_generation;
    n_tokens_predicted       += n_tokens;
    n_tokens_predicted_total += n_tokens;
    t_tokens_generation       += t_ms;
    t_tokens_generation_total += t_ms;
}
// Clear only the windowed counters; the *_total accumulators keep running.
void server_metrics::reset_bucket() {
    n_prompt_tokens_processed = 0;
    n_tokens_predicted = 0;
    t_prompt_processing = 0;
    t_tokens_generation = 0;
}
// Tokenize a JSON prompt that is either a plain string or a mixed array of
// strings and raw token ids. `add_special` (BOS etc.) applies only to the
// first element. Behavior is unchanged; this rewrite removes the shadowing of
// the loop variable `p` by the inner token vector of the same name.
std::vector<llama_token> server_context::tokenize(const json& json_prompt, bool add_special) const {
    // TODO: currently, we tokenize using special tokens by default
    // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
    // but it's better compared to completely ignoring ChatML and other chat templates
    const bool TMP_FORCE_SPECIAL = true;
    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
    // or the first element of the json_prompt array is a string.
    std::vector<llama_token> prompt_tokens;
    if (json_prompt.is_array()) {
        bool first = true;
        for (const auto& elem : json_prompt) {
            if (elem.is_string()) {
                const auto s = elem.template get<std::string>();
                // special tokens (BOS) only for the very first element
                const auto toks = ::common_tokenize(ctx, s, first && add_special, TMP_FORCE_SPECIAL);
                first = false;
                prompt_tokens.insert(prompt_tokens.end(), toks.begin(), toks.end());
            }
            else {
                // raw token id - also consumes the "first element" status
                first = false;
                prompt_tokens.push_back(elem.template get<llama_token>());
            }
        }
    }
    else {
        const auto s = json_prompt.template get<std::string>();
        prompt_tokens = ::common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
    }
    return prompt_tokens;
}
// Look up a slot by its id; returns nullptr when no slot matches.
server_slot* server_context::get_slot_by_id(int id) {
    for (auto& slot : slots) {
        if (slot.id != id) {
            continue;
        }
        return &slot;
    }
    return nullptr;
}
// Compute the fraction of slot `a`'s cached tokens that would be kept when
// switching the slot to prompt `b` (used to decide whether the current cache
// is worth saving). Returns 0 when the cache is empty.
// NOTE(review): the `ctx` parameter is unused - slot.ctx is used instead;
// presumably intentional for signature symmetry, but worth confirming.
float server_context::calculate_slot_f_keep(const server_slot & slot, llama_context * ctx,const server_tokens & a, const server_tokens & b) {
float f_keep = 0.0f;
if (!a.empty()) {
// when a context shift already discarded prompt tokens (and group
// attention is off), compare using the shifted alignment
if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
}
else {
f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
}
}
return f_keep;
}
std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b) {
std::pair<common_prefix, float> sim;
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
// fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
// handle context shift
if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
if (sim_cur_ctx_shift > sim_cur) {
sim_cur = sim_cur_ctx_shift;
}
}
sim.first = lcp_len;
sim.second = sim_cur;
return sim;
}
// Copy the slot's current cache metadata (tokens and context-shift counters)
// into its cached-prompt record, ready for saving to the prompt cache.
void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
    auto & cached = slot.server_cached_prompt;
    cached.tokens             = tokens.clone(); // deep copy of the cache tokens
    cached.n_discarded_prompt = slot.n_discarded_prompt;
    cached.n_kept_prompt      = slot.n_kept_prompt;
    cached.think_tokens       = slot.params.think_tokens;
}
// Pick a slot for `task`: first by prompt similarity (when enabled), then by
// least-recently-used. Before handing the slot over, optionally saves its
// current context to the prompt cache (if a large portion would be lost) and
// tries to restore a better-matching cached prompt.
server_slot* server_context::get_available_slot(const server_task& task) {
server_slot* ret = nullptr;
bool update_cache = false;
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int max_lcp_len = 0;
float sim_best = 0;
for (server_slot& slot : slots) {
// skip the slot if it is not available
if (!slot.available()) {
continue;
}
auto& cache_tokens = slot.cache_tokens;
// skip the slot if it does not contains prompt
if (cache_tokens.empty()) {
continue;
}
std::pair<common_prefix, float> sim;
// optionally compare with reasoning tokens stripped from both sides
if (slot.params.think_tokens.exclude) {
server_tokens cache_tokens_exclude_think = slot.cache_tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
server_tokens prompt_tokens_exclude_think = task.tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
}
else {
sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
}
common_prefix lcp_len = sim.first;
float sim_cur = sim.second;
// select the current slot if the criteria match
if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
sim_best = sim_cur;
max_lcp_len = lcp_len.first;
ret = &slot;
}
}
if (ret != nullptr) {
LOG_VERBOSE("selected slot by lcp similarity", {
{"id_slot", ret->id},
{"max_lcp_len", max_lcp_len},
{"similarity", sim_best},
});
}
}
// find the slot that has been least recently used
if (ret == nullptr) {
int64_t t_last = ggml_time_us();
for (server_slot& slot : slots) {
// skip the slot if it is not available
if (!slot.available()) {
continue;
}
// select the current slot if the criteria match
if (slot.t_last_used < t_last) {
t_last = slot.t_last_used;
ret = &slot;
}
}
if (ret != nullptr) {
LOG_VERBOSE("selected slot by lru", {
{"id_slot", ret->id},
{"t_last", t_last},
});
}
}
// prompt-cache bookkeeping for the chosen slot
if (ret) {
auto& tokens = ret->cache_tokens;
float f_keep = 0;
size_t cache_token_size = tokens.size();
if (!tokens.empty()) {
// f_keep: fraction of the current cache that survives switching to the
// new prompt (reasoning tokens optionally excluded, as above)
if (ret->params.think_tokens.exclude) {
server_tokens cache_exclude_think = tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
server_tokens prompt_exclude_think = task.tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
cache_token_size = cache_exclude_think.size();
f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
}
else {
f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
}
// if we are about to lose a large portion of the existing context - save it in the prompt cache
if (f_keep < cache_ram_similarity) {
update_cache = true;
}
}
update_cache = update_cache && prompt_cache;
// cache prompts only for completion tasks
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
// don't update the cache if the slot's context is above cache_ram_n_min
update_cache = update_cache && cache_token_size >= cache_ram_n_min;
LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
(int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity);
if (update_cache) {
const int64_t t_start = ggml_time_us();
LLAMA_LOG_INFO("updating prompt cache\n");
// copy cache tokens
copy_data_to_cached_prompt(tokens, *ret);
ret->prompt_save(*prompt_cache);
LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
// has prompts saved earlier to load
if (prompt_cache && !prompt_cache->states.empty()) {
const int64_t t_start = ggml_time_us();
copy_data_to_cached_prompt(tokens, *ret);
ret->prompt_load(*prompt_cache, task.tokens);
prompt_cache->update();
ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens
ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt;
ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt;
LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
}
return ret;
}
// Fill `vocab_pieces` with the detokenized text of every vocab token so
// repeated lookups don't have to call common_token_to_piece again.
// Returns the vocab size. Fix: the early-out comparison mixed a size_t with
// an int32_t (signed/unsigned comparison); an explicit cast makes it exact.
int32_t server_context::populate_vocab_pieces() {
    const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
    // already populated - nothing to do
    if (vocab_pieces.size() == static_cast<size_t>(n_vocab)) {
        return n_vocab;
    }
    vocab_pieces.clear();
    vocab_pieces.reserve(n_vocab);
    for (int32_t id = 0; id < n_vocab; ++id) {
        vocab_pieces.push_back(common_token_to_piece(ctx, id, true));
    }
    return n_vocab;
}
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
slot_params defaults;
defaults.speculative = params_base.speculative;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
common_params_sampling default_sparams = params_base.sparams;
auto& data = task.data;
const llama_vocab* vocab = llama_model_get_vocab(model);
if (data.count("__oaicompat") != 0) {
slot.oaicompat = true;
slot.oaicompat_model = task.params.oaicompat_model;
}
else {
slot.oaicompat = false;
slot.oaicompat_model = "";
}
slot.params.oaicompat = task.params.oaicompat;
slot.params.oaicompat_cmpl_id =task.params.oaicompat_cmpl_id;
slot.oai_resp_thinking_block_started = false;
slot.oai_resp_text_block_started = false;
slot.oai_resp_id = "resp_" + random_string();
slot.oai_resp_reasoning_id = "rs_" + random_string();
slot.oai_resp_message_id = "msg_" + random_string();
slot.oai_resp_fc_id.clear();
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
slot.params.stream = json_value(data, "stream", false);
auto stream_opt = json_value(data, "stream_options", json::object());
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
slot.saturate_predict = json_value(data, "saturate_predict", false);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
slot.sparams.top_n_sigma = json_value(data, "top_n_sigma", default_sparams.top_n_sigma);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
slot.sparams.adaptive_updt_w_cur = json_value(data, "adaptive_updt_w_cur", default_sparams.adaptive_updt_w_cur);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
// speculative decoding parameters
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.speculative.n_max);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.speculative.n_min);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
slot.params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
// Clamp speculative parameters
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
slot.params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
slot.params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
slot.params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
slot.params.speculative.ngram_size_n = std::max(std::min(1, (int)slot.params.speculative.ngram_size_n), 1024);
slot.params.speculative.ngram_size_m = std::max(std::min(1, (int)slot.params.speculative.ngram_size_m), 1024);
slot.params.speculative.ngram_min_hits = std::max(std::min(1, (int)slot.params.speculative.ngram_min_hits), 1024);
if (slot.sparams.penalty_last_n < -1) {
throw std::runtime_error("Error: repeat_last_n must be >= -1");
}
if (slot.sparams.dry_penalty_last_n < -1) {
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
}
if (slot.sparams.penalty_last_n == -1) {
// note: should be the slot's context and not the full context, but it's ok
slot.sparams.penalty_last_n = llama_n_ctx(ctx);
}
if (slot.sparams.dry_penalty_last_n == -1) {
slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
}
if (slot.sparams.dry_base < 1.0f)
{
slot.sparams.dry_base = default_sparams.dry_base;
}
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (slot.sparams.dry_sequence_breakers.empty()) {
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
LLAMA_LOG_DEBUG("JSON schema: %s\n", schema.dump(2).c_str());
slot.sparams.grammar = json_schema_to_grammar(schema);
LLAMA_LOG_DEBUG("Converted grammar: %s\n", slot.sparams.grammar.c_str());
}
catch (const std::exception& e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
}
else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
LLAMA_LOG_DEBUG("Grammar: %s\n", slot.sparams.grammar.c_str());
slot.sparams.grammar_lazy = json_value(data, "grammar_lazy", default_sparams.grammar_lazy);
LLAMA_LOG_DEBUG("Grammar lazy: %s\n", slot.sparams.grammar_lazy ? "true" : "false");
}
if (slot.params.cache_prompt && slot.ga_n != 1) {
LOG_WARNING("cache_prompt is not supported with group-attention", {});
slot.params.cache_prompt = false;
}
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
LOG_WARNING("Max tokens to predict exceeds server configuration", {
{"params.n_predict", slot.params.n_predict},
{"slot.n_predict", slot.n_predict},
});
slot.params.n_predict = slot.n_predict;
}
// infill
slot.params.input_prefix = json_value(data, "input_prefix", defaults.input_prefix);
slot.params.input_suffix = json_value(data, "input_suffix", defaults.input_suffix);
// get prompt
if (!task.infill) {
slot.prompt_tokens = std::move(task.tokens);
const auto & prompt = data.find("prompt");
if (prompt != data.end()) {
if (prompt->is_string() ||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
}
}
}
// penalize user-provided tokens
{
slot.sparams.penalty_prompt_tokens.clear();
slot.sparams.use_penalty_prompt_tokens = false;
const auto& penalty_prompt = data.find("penalty_prompt");
if (penalty_prompt != data.end()) {
if (penalty_prompt->is_string()) {
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
slot.sparams.penalty_prompt_tokens = common_tokenize(model, penalty_prompt_string, false);
if (slot.params.n_predict > 0) {
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
}
slot.sparams.use_penalty_prompt_tokens = true;
LOG_VERBOSE("penalty_prompt_tokens", {
{"id_slot", slot.id},
{"tokens", slot.sparams.penalty_prompt_tokens},
});
}
else if (penalty_prompt->is_array()) {
const auto n_tokens = penalty_prompt->size();
slot.sparams.penalty_prompt_tokens.clear();
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
const int n_vocab = llama_n_vocab(model);
for (const auto& penalty_token : *penalty_prompt) {
if (penalty_token.is_number_integer()) {
const auto tok = penalty_token.get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.sparams.penalty_prompt_tokens.push_back(tok);
}
}
}
slot.sparams.use_penalty_prompt_tokens = true;
LOG_VERBOSE("penalty_prompt_tokens", {
{"id_slot", slot.id},
{"tokens", slot.sparams.penalty_prompt_tokens},
});
}
}
}
{
auto it = data.find("chat_format");
if (it != data.end()) {
slot.params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
LLAMA_LOG_DEBUG("Chat format: %s\n", common_chat_format_name(slot.params.oaicompat_chat_syntax.format));
}
else {
slot.params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
}
common_reasoning_format reasoning_format = params_base.reasoning_format;
if (data.contains("reasoning_format")) {
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
}
slot.params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
slot.params.oaicompat_chat_syntax.reasoning_in_content = slot.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
slot.params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
if (data.contains("chat_parser")) {
slot.params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
}
slot.params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
}
{
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
slot.sparams.preserved_tokens.clear();
for (const auto& t : *preserved_tokens) {
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG("Preserved token: %d\n", ids[0]);
slot.sparams.preserved_tokens.insert(ids[0]);
}
else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
LOG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
}
}
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
slot.sparams.grammar_triggers.clear();
for (const auto& t : *grammar_triggers) {
server_grammar_trigger ct(t);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto& word = ct.value.value;
auto ids = common_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) {
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
LOG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
trigger.value = word;
trigger.token = token;
slot.sparams.grammar_triggers.push_back(std::move(trigger));
}
else {
LOG("Grammar trigger word: `%s`\n", word.c_str());
slot.sparams.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word });
}
}
else {
//slot.sparams.grammar_triggers.push_back(ct);
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
LLAMA_LOG_DEBUG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
}
else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
LLAMA_LOG_DEBUG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
}
else {
throw std::runtime_error("Unknown grammar trigger type");
}
slot.sparams.grammar_triggers.emplace_back(std::move(ct.value));
}
}
}
if (slot.sparams.grammar_lazy && slot.sparams.grammar_triggers.empty()) {
throw std::runtime_error("Error: no triggers set for lazy grammar!");
}
}
{ // apply logit bias
const auto& logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
slot.sparams.logit_bias.clear(); // only clear if user sets it
}
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
for (const auto& el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
float bias;
if (el[1].is_number()) {
bias = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>()) {
bias = -INFINITY;
}
else {
continue;
}
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias[tok] = bias;
}
}
else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
}
}
}
}
}
else if (logit_bias != data.end() && logit_bias->is_object()) {
const int n_vocab = llama_vocab_n_tokens(vocab);
for (const auto& el : logit_bias->items()) {
float bias;
const auto& key = el.key();
const auto& value = el.value();
if (value.is_number()) {
bias = value.get<float>();
}
else if (value.is_boolean() && !value.get<bool>()) {
bias = -INFINITY;
}
else {
continue;
}
char* end;
llama_token tok = strtol(key.c_str(), &end, 10);
if (*end == 0) {
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias[tok] = bias;
}
}
else {
auto toks = common_tokenize(model, key, false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
}
}
}
}
if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
}
{
// ban string
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
const auto& banned_strings = data.find("banned_strings");
if (banned_strings != data.end() && banned_strings->is_array()) {
slot.ban_phrases.clear();
for (const auto& val : data["banned_strings"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
if (!s.empty()) {
s = string_lower(s);
// Use string length instead of token count
if (s.length() > slot.n_buffer) {
slot.n_buffer = s.length();
}
slot.ban_phrases.push_back(s);
}
}
}
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
return a.length() > b.length();
});
} else if (params_base.ban_phrases.size() > 0) {
if (params_base.n_buffer == 0) {
slot.ban_phrases.clear();
std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) {
return a.length() > b.length();
});
for (auto & val : params_base.ban_phrases) {
if (!val.empty()) {
val = string_lower(val);
// Use string length instead of token count
if (val.length() > slot.n_buffer) {
slot.n_buffer = val.length();
}
slot.ban_phrases.push_back(val);
}
}
params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
} else {
slot.ban_phrases = params_base.ban_phrases;
slot.n_buffer = params_base.n_buffer;
}
}
// ban regex
slot.ban_regex.clear();
const auto& banned_regex = data.find("banned_regex");
if (banned_regex != data.end() && banned_regex->is_array()) {
for (const auto& val : data["banned_regex"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
if (!s.empty()) {
try {
std::regex re(s);
slot.ban_regex.push_back(s);
if (s.length() > slot.n_buffer) {
slot.n_buffer = s.length();
}
} catch (const std::regex_error& e) {
send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
}
}
// ban regex case insensitive
slot.ban_regex_ci.clear();
const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
for (const auto& val : data["banned_regex_case_insensitive"]) {
if (val.is_string()) {
std::string s = val.get<std::string>();
if (!s.empty()) {
try {
std::regex re(s, std::regex_constants::icase);
slot.ban_regex_ci.push_back(s);
if (s.length() > slot.n_buffer) {
slot.n_buffer = s.length();
}
} catch (const std::regex_error& e) {
send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
}
}
if (banbuffer_size > 0) {
slot.n_buffer = banbuffer_size;
} else {
slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
}
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
}
do // populate allowlist biases
{
// TODO: JSON parsing for rules and keywords
slot.allow_ruless = params_base.allow_ruless;
if (slot.allow_ruless.size() == 0) {
slot.allow_biasess.clear();
break;
}
slot.allow_kws = params_base.allow_kws;
slot.allow_pieces = params_base.allow_pieces;
const auto& allowlist_piece_array = data.find("allowlist_piece_array");
if (allowlist_piece_array != data.end() && allowlist_piece_array->is_array()) {
slot.allow_pieces.clear();
for (const auto& piece: *allowlist_piece_array) {
if (piece.is_string()) {
slot.allow_pieces.push_back(piece.get<std::string>());
}
}
}
slot.allow_kw_delay = json_value(data, "allowlist_keyword_delay", params_base.allow_kw_delay);
// end of allowlist criteria update
const int32_t n_vocab = populate_vocab_pieces();
std::unordered_set<llama_token> allow_settoken;
for (const auto& piece: slot.allow_pieces) {
for (const auto token: common_tokenize(model, piece, false, true)) {
allow_settoken.insert(token);
}
}
auto n_rules = slot.allow_ruless.size();
if (n_rules > slot.allow_kws.size() + 1) {
// one more rules than keyword, last rules do not expire
n_rules = slot.allow_kws.size() + 1;
slot.allow_ruless.resize(n_rules);
} else if (n_rules < slot.allow_kws.size()) {
// every rules expire
slot.allow_kws.resize(n_rules);
}
slot.allow_biasess.resize(n_rules);
for (size_t i = 0; i < n_rules; ++i) {
const auto& rules = slot.allow_ruless[i];
if ((i < slot.allow_ruless_prev.size()) && (rules == slot.allow_ruless_prev[i])) {
continue;
}
LLAMA_LOG_DEBUG("%s: allowlist %zu is new\n", __func__, i);
auto& biases = slot.allow_biasess[i];
biases.resize(n_vocab);
std::vector<uint32_t> cpts;
std::vector<std::string> scripts;
for (size_t id = 0; id < n_vocab; ++id) {
const size_t n_cpt = llama_fill_from_utf8(&vocab_pieces[id], &cpts, &scripts);
float bias = -INFINITY;
// each codepoint must be found in
for (size_t j = 0; j < n_cpt; ++j) {
bool in_rule = false;
// at least one rule
for (const auto& rule: rules) {
const bool in_range = (std::get<0>(rule) <= cpts[j]) && (cpts[j] <= std::get<1>(rule));
in_rule = in_range && ((std::get<2>(rule) == "*") || std::get<2>(rule) == scripts[j]);
if (in_rule) {
// earlier rule has higher priority
bias = std::max(bias, std::get<3>(rule));
break;
}
}
if (!in_rule) {
if ((scripts[j] == "common") || (scripts[j] == "inherited")) {
// for common or inherited codepoints (e.g. whitespace), defer to other codepoints in the token
continue;
}
// to shadow realm
bias = -INFINITY;
break;
}
}
biases[id] = bias;
}
float max_bias = -INFINITY;
for (const auto& rule: rules) {
max_bias = std::max(max_bias, std::get<3>(rule));
}
for (const auto token: allow_settoken) {
biases[token] = max_bias;
}
}
} while (false);
slot.allow_ruless_prev = slot.allow_ruless;
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
params_base.can_ban_phrases = false;
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
// make checkpoints only for completion tasks
do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
params_base.do_checkpoint = do_checkpoint;
if (slot.n_buffer != 0) {
LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n");
}
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled");
}
}
{
const auto& stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
slot.params.antiprompt.clear();
for (const auto& word : *stop) {
if (!word.empty()) {
slot.params.antiprompt.push_back(word);
}
}
}
}
{
const auto samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
slot.sparams.samplers_sequence = llama_sampling_types_from_names(*samplers, false);
}
else if (samplers->is_string()) {
slot.sparams.samplers_sequence = llama_sampling_types_from_chars(samplers->get<std::string>());
}
else {
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
}
}
}
{
if (slot.ctx_sampling != nullptr) {
common_sampler_free(slot.ctx_sampling);
}
slot.ctx_sampling = common_sampler_init(model, slot.sparams);
if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
slot.command = SLOT_COMMAND_LOAD_PROMPT;
// slot.prompt_tokens.clear();
LOG_INFO("slot is processing task", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
});
slot.task = std::make_unique<const server_task>(std::move(task));
return true;
}
void server_context::kv_cache_clear() {
    LOG_VERBOSE("clearing KV cache", {});

    // drop every token from every sequence in the KV cache
    llama_kv_cache_clear(ctx);

    // cache is now clean; no deferred cleanup pending
    clean_kv_cache = false;
}
void server_context::system_prompt_update() {
    LOG_VERBOSE("system prompt update", {
        {"system_prompt", system_prompt},
    });

    // start from an empty KV cache, then re-tokenize the system prompt
    kv_cache_clear();
    system_tokens.clear();

    if (!system_prompt.empty()) {
        system_tokens = ::common_tokenize(ctx, system_prompt, true);

        const int32_t n_batch = llama_n_batch(ctx);
        const int32_t n_total = system_tokens.size();

        // decode the system prompt into sequence 0, one batch at a time
        for (int32_t ofs = 0; ofs < n_total; ofs += n_batch) {
            const int32_t n_eval = std::min(n_batch, n_total - ofs);

            common_batch_clear(batch);
            for (int32_t k = 0; k < n_eval; ++k) {
                common_batch_add(batch, system_tokens[ofs + k], ofs + k, { 0 }, false);
            }

            if (llama_decode(ctx, batch) != 0) {
                // leave system_need_update set so a retry is possible
                LOG_ERROR("llama_decode() failed", {});
                return;
            }
        }

        // share the system KV cache with every parallel sequence
        for (int32_t seq = 1; seq <= params_base.n_parallel; ++seq) {
            llama_kv_cache_seq_cp(ctx, 0, seq, -1, -1);
        }
    }

    system_need_update = false;
}
bool server_context::system_prompt_set(const std::string& sys_prompt) {
system_prompt = sys_prompt;
LOG_VERBOSE("system prompt process", {
{"system_prompt", system_prompt},
});
// release all slots
for (server_slot& slot : slots) {
slot.release();
}
system_need_update = true;
return true;
}
// keep in sync with process_token(completion_token_output& result, server_slot& slot)
bool server_context::has_next_token(const completion_token_output& result, server_slot& slot) {
    // out of generation budget for this request?
    if (slot.n_decoded > 0 && !slot.has_budget(params_base)) {
        return false;
    }

    // end-of-generation token sampled?
    if (llama_token_is_eog(model, result.tok)) {
        return false;
    }

    // with no token limit and no self-extend, cap generation at the training
    // context size to avoid an EOS-less infinite loop
    const auto n_ctx_train = llama_n_ctx_train(model);
    if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
        && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
        return false;
    }

    return true;
}
// Apply a freshly sampled token to the slot: append its text, handle stop
// strings, stream a partial response if enabled, and update the slot's stop
// flags. Returns slot.has_next_token, i.e. true while generation continues.
// (keep in sync with has_next_token(...) above)
bool server_context::process_token(completion_token_output& result, server_slot& slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
slot.sampled = result.tok;
// search stop word and delete it
slot.last_gentxt_size = slot.generated_text.size();
slot.generated_text += token_str;
slot.has_next_token = true;
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
// we can change penalty_prompt_tokens because it is always created from scratch each request
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
}
// check if there is incomplete UTF-8 character at the end; if so, defer
// sending text until the multi-byte sequence is completed by a later token
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
if (!incomplete) {
// pos = start of the not-yet-sent portion of generated_text
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos);
bool send_text = true;
// full stop-string match: truncate generated_text at the match
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
if (stop_pos != std::string::npos) {
slot.generated_text.erase(
slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
}
else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) {
// partial stop-string match at the tail: hold the text back for now
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
send_text = stop_pos == std::string::npos;
}
// check if there is any token to predict
if (send_text) {
// do not send the stop word in the response
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.n_sent_text += result.text_to_send.size();
// add the token to slot queue and cache
}
else {
result.text_to_send = "";
}
slot.add_token_string(result);
if (slot.params.stream) {
send_partial_response(slot, result);
}
}
if (incomplete) {
// keep generating so the pending UTF-8 sequence can complete
slot.has_next_token = true;
}
// check the limits: stop once the per-request budget is exhausted
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stopped_limit = true;
slot.has_next_token = false;
LOG_VERBOSE("stopped by limit", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_decoded", slot.n_decoded},
{"n_predict", slot.params.n_predict},
});
}
if (llama_token_is_eog(model, result.tok)) {
slot.stopped_eos = true;
slot.has_next_token = false;
LOG_VERBOSE("eos token found", {});
}
// safety net: without a token limit and without self-extend, cap generation
// at the training context size (mirrors the check in has_next_token)
auto n_ctx_train = llama_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
LOG_WARNING("n_predict is not set and self-context extend is disabled."
" Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
{ "id_slot", slot.id },
{ "params.n_predict", slot.params.n_predict },
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
{ "slot.n_decoded", slot.n_decoded },
{ "slot.n_predict", slot.n_predict },
{ "n_slots", params_base.n_parallel },
{ "slot.n_ctx", slot.n_ctx },
{ "n_ctx", n_ctx },
{ "n_ctx_train", n_ctx_train },
{ "ga_n", slot.ga_n },
});
slot.truncated = true;
slot.stopped_limit = true;
slot.has_next_token = false; // stop prediction
}
log_text(params_base, "token:"+result.text_to_send);
LOG_VERBOSE("next token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
{"has_next_token", slot.has_next_token},
{"n_remain", slot.n_remaining},
{"n_decoded", slot.n_decoded},
{"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word},
{"stopped_limit", slot.stopped_limit},
{"stopping_word", slot.stopping_word},
});
return slot.has_next_token; // continue
}
// Fill result.prob / result.probs with token probabilities, either after the
// sampler chain (post_sampling) or from the raw model logits at batch index idx.
void server_context::populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx) {
    const size_t n_probs = slot.sparams.n_probs;
    const size_t n_vocab = llama_n_vocab(llama_get_model(ctx));

    if (post_sampling) {
        // candidates as seen after the sampler chain was applied
        const auto* cand   = common_sampler_get_candidates(slot.ctx_sampling);
        const size_t n_cand = cand->size;

        // probability of the token that was actually sampled
        // (result.prob stays untouched if the token is not among the candidates)
        for (size_t i = 0; i < n_cand; ++i) {
            if (cand->data[i].id == result.tok) {
                result.prob = cand->data[i].p;
                break;
            }
        }

        // top n_probs candidates, in sampler order
        result.probs.reserve(n_cand);
        const size_t n_out = std::min(n_cand, n_probs);
        for (size_t i = 0; i < n_out; ++i) {
            result.probs.push_back({
                cand->data[i].id,
                common_token_to_piece(ctx, cand->data[i].id, special),
                cand->data[i].p
            });
        }
    } else {
        // raw model probabilities, before any sampling transforms
        auto&& [p_sampled, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);

        result.prob = p_sampled;

        result.probs.reserve(n_probs);
        const size_t n_out = std::min(n_vocab, n_probs);
        for (size_t i = 0; i < n_out; ++i) {
            result.probs.push_back({
                cur[i].id,
                common_token_to_piece(ctx, cur[i].id, special),
                cur[i].p
            });
        }
    }
}
// Build a JSON snapshot of the slot's effective generation settings, echoing
// back the request parameters after defaulting/clamping. Returned under
// "generation_settings" in the final response and by the slot-info endpoints.
json server_context::get_formated_generation(const server_slot& slot) const {
    // "ignore_eos" is represented internally as a -inf logit bias on the EOS token
    const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
    const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
    // sampler chain as human-readable names
    std::vector<std::string> samplers_sequence;
    samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
    for (const auto& sampler_type : slot.sparams.samplers_sequence) {
        samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
    }
    auto grammar_triggers = json::array();
    for (const auto& trigger : slot.sparams.grammar_triggers) {
        grammar_triggers.push_back(trigger.to_json<json>());
    }
    return json{
        {"n_ctx", slot.n_ctx},
        {"n_predict", slot.n_predict}, // Server configured n_predict
        {"model", params_base.model_alias},
        {"seed", slot.sparams.seed},
        {"temperature", slot.sparams.temp},
        {"dynatemp_range", slot.sparams.dynatemp_range},
        {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
        {"top_k", slot.sparams.top_k},
        {"top_p", slot.sparams.top_p},
        {"min_p", slot.sparams.min_p},
        // accepted as a request parameter; echo it back like the other samplers
        {"top_n_sigma", slot.sparams.top_n_sigma},
        {"tfs_z", slot.sparams.tfs_z},
        {"typical_p", slot.sparams.typical_p},
        {"repeat_last_n", slot.sparams.penalty_last_n},
        {"repeat_penalty", slot.sparams.penalty_repeat},
        {"presence_penalty", slot.sparams.penalty_present},
        {"frequency_penalty", slot.sparams.penalty_freq},
        {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
        {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
        {"dry_multiplier", slot.sparams.dry_multiplier},
        {"dry_base", slot.sparams.dry_base},
        {"dry_allowed_length", slot.sparams.dry_allowed_length},
        {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
        {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
        {"mirostat", slot.sparams.mirostat},
        {"mirostat_tau", slot.sparams.mirostat_tau},
        {"mirostat_eta", slot.sparams.mirostat_eta},
        {"adaptive_target", slot.sparams.adaptive_target},
        {"adaptive_decay", slot.sparams.adaptive_decay},
        {"adaptive_updt_w_cur", slot.sparams.adaptive_updt_w_cur},
        {"penalize_nl", slot.sparams.penalize_nl},
        {"stop", slot.params.antiprompt},
        {"max_tokens", slot.params.n_predict}, // User configured n_predict
        {"n_keep", slot.params.n_keep},
        {"n_discard", slot.params.n_discard},
        {"ignore_eos", ignore_eos},
        {"stream", slot.params.stream},
        {"logit_bias", slot.sparams.logit_bias},
        {"n_probs", slot.sparams.n_probs},
        {"min_keep", slot.sparams.min_keep},
        {"grammar", slot.sparams.grammar},
        {"grammar_triggers", grammar_triggers},
        {"preserved_tokens", slot.sparams.preserved_tokens},
        {"chat_format", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)},
        {"reasoning_format", common_reasoning_format_name(slot.params.oaicompat_chat_syntax.reasoning_format)},
        {"reasoning_in_content", slot.params.oaicompat_chat_syntax.reasoning_in_content},
        {"thinking_forced_open", slot.params.oaicompat_chat_syntax.thinking_forced_open},
        {"samplers", samplers_sequence}
    };
}
// Convenience overload: report an error for a task, forwarding its ids.
void server_context::send_error(const server_task& task, const std::string& error, const enum error_type type) {
send_error(task.id, task.id_multi, error, type);
}
// Convenience overload: report an error for a slot, forwarding its task ids.
void server_context::send_error(const server_slot& slot, const std::string& error, const enum error_type type) {
send_error(slot.id_task, slot.id_multi, error, type);
}
// Log the failure and push an error result onto the result queue so the
// HTTP layer can translate it into an error response.
void server_context::send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type ) {
    LOG_ERROR("task error", {
        {"id_multi", id_multi},
        {"id_task", id_task},
        {"error", error},
    });

    auto err_res = std::make_unique<server_task_result_error>();
    err_res->id       = id_task;
    err_res->id_multi = id_multi;
    err_res->stop     = false;
    err_res->error    = true;
    err_res->err_type = type;
    err_res->err_msg  = error;

    queue_results.send(std::move(err_res));
}
// if multimodal is enabled, send an error and return false
bool server_context::check_no_mtmd(const int id_task) {
    if (!mctx) {
        // multimodal disabled - the feature is allowed
        return true;
    }
    const int id_multi = 0;
    send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
    return false;
}
// Stream one token's worth of output for a slot: snapshot the slot state into
// a partial-result object and push it onto the result queue.
// NOTE(review): tkn is taken by value; it is also copied into probs_output below.
void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
// the slot's task may already have been released (e.g. client disconnect)
if (slot.task == nullptr) {
return;
}
auto res = std::make_unique<server_task_result_cmpl_partial>();
res->final_result = false;
res->id = slot.id_task;
res->id_multi = slot.id_multi;
res->index = slot.task->index;
res->error = false;
res->stop = false;
res->stream = slot.params.stream;
res->content = tkn.text_to_send;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oai_resp_id = slot.oai_resp_id;
res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id;
res->oai_resp_message_id = slot.oai_resp_message_id;
res->oai_resp_fc_id = slot.oai_resp_fc_id;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->data = json{
{"content", tkn.text_to_send},
{"stop", false},
{"id_slot", slot.id},
{"multimodal", false}
};
slot.update_chat_msg(res->oaicompat_msg_diffs);
res->anthropic_has_reasoning = !slot.chat_msg.reasoning_content.empty();
// the *_block_started flags are copied BEFORE the diff loop below updates
// them on the slot, so this response reflects the pre-token state
res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started;
res->anthropic_text_block_started = slot.anthropic_text_block_started;
res->oai_resp_thinking_block_started = slot.oai_resp_thinking_block_started;
res->oai_resp_text_block_started = slot.oai_resp_text_block_started;
// mark which content blocks have now been opened, for subsequent responses
for (const auto& diff : res->oaicompat_msg_diffs) {
if (!diff.reasoning_content_delta.empty() && !slot.anthropic_thinking_block_started) {
slot.anthropic_thinking_block_started = true;
}
if (!diff.content_delta.empty() && !slot.anthropic_text_block_started) {
slot.anthropic_text_block_started = true;
}
if (!diff.reasoning_content_delta.empty() && !slot.oai_resp_thinking_block_started) {
slot.oai_resp_thinking_block_started = true;
}
if (!diff.content_delta.empty() && !slot.oai_resp_text_block_started) {
slot.oai_resp_text_block_started = true;
}
if (!diff.tool_call_delta.name.empty()) {
slot.oai_resp_fc_id = diff.tool_call_delta.id;
}
}
// populate res->probs_output
if (slot.sparams.n_probs > 0) {
res->probs_output = { tkn }; // copy the token probs
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
}
if (slot.oaicompat) {
res->data["oaicompat_token_ctr"] = slot.n_decoded;
res->data["model"] = slot.oaicompat_model;
}
// populate timings if this is final response or timings_per_token is enabled
if (slot.params.timings_per_token) {
res->timings = slot.get_timings();
}
queue_results.send(std::move(res));
}
// Build and enqueue the final completion result for a finished slot:
// full generated text, stop reason flags, timings, generation settings and
// (optionally) the accumulated per-token probabilities.
void server_context::send_final_response(server_slot& slot) {
    // Guard against a detached task, mirroring send_partial_response();
    // the original dereferenced slot.task unconditionally below.
    if (slot.task == nullptr) {
        return;
    }
    auto res = std::make_unique<server_task_result_cmpl_final>();
    res->final_result = true;
    res->id = slot.id_task;
    res->id_multi = slot.id_multi;
    res->index = slot.task->index;
    res->error = false;
    res->stop = true; // to do: set value
    res->stream = slot.params.stream;
    res->include_usage = slot.params.include_usage;
    res->content = slot.generated_text;
    res->timings = slot.get_timings();
    res->post_sampling_probs = slot.params.post_sampling_probs;
    res->oaicompat = slot.params.oaicompat;
    res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
    // final diff pass so trailing partial parses (tool calls etc.) are flushed
    res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
    res->oai_resp_id = slot.oai_resp_id;
    res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id;
    res->oai_resp_message_id = slot.oai_resp_message_id;
    res->n_decoded = slot.n_decoded;
    res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started;
    res->anthropic_text_block_started = slot.anthropic_text_block_started;
    res->n_prompt_tokens = slot.n_prompt_tokens;
    res->oaicompat_model = slot.task->params.oaicompat_model;
    res->data = json{
        {"content", !slot.params.stream ? slot.generated_text : ""},
        {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
        {"id_slot", slot.id},
        {"stop", true},
        {"model", params_base.model_alias},
        {"tokens_predicted", slot.n_decoded},
        {"tokens_evaluated", slot.n_prompt_tokens},
        {"generation_settings", get_formated_generation(slot)},
        {"prompt", slot.prompt},
        {"truncated", slot.truncated},
        {"stopped_eos", slot.stopped_eos},
        {"stopped_word", slot.stopped_word},
        {"stopped_limit", slot.stopped_limit},
        {"stopping_word", slot.stopping_word},
        {"tokens_cached", slot.n_past},
        {"timings", slot.get_formated_timings()},
        //{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
    };
    // populate res->probs_output
    if (slot.sparams.n_probs > 0) {
        res->probs_output = std::vector<completion_token_output>(
            slot.generated_token_probs.begin(),
            slot.generated_token_probs.end());
        res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
    }
    if (slot.oaicompat) {
        res->data["oaicompat_token_ctr"] = slot.n_decoded;
        res->data["model"] = slot.oaicompat_model;
    }
    queue_results.send(std::move(res));
}
// Extract embeddings for the slot's sequence from the decoded batch and
// enqueue them as a result. With pooling, a single (normalized) vector is
// produced; without pooling, one raw vector per output token.
void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
    auto res = std::make_unique<server_task_result_embd>();
    res->id = slot.task->id;
    res->index = slot.task->index;
    res->server_task_result::index = slot.task->index;
    res->n_tokens = slot.prompt_tokens.size();
    res->oaicompat = slot.task->params.oaicompat;
    const int n_embd = llama_model_n_embd(model);
    std::vector<float> embd_res(n_embd, 0.0f);
    for (int i = 0; i < batch.n_tokens; ++i) {
        // only tokens that requested logits/embeddings and belong to this slot's sequence
        if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
            continue;
        }
        const float* embd = nullptr;
        // NOTE(review): pooling is queried on slot.ctx but embeddings are read
        // from the server's ctx — presumably these are the same context; verify
        if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
            embd = llama_get_embeddings_ith(ctx, i);
        }
        else {
            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        }
        if (embd == nullptr) {
            SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
            // keep positional alignment by emitting a zero vector
            res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
            continue;
        }
        // normalize only when there is pooling
        if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
            common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
            res->embedding.push_back(embd_res);
            // with pooling there is one embedding for the whole sequence
            break;
        }
        res->embedding.emplace_back(embd, embd + n_embd);
    }
    queue_results.send(std::move(res));
}
// Point the slot's sampler at the bias set selected by allow_idx, or detach
// the biases entirely when the index is past the end of the list.
void server_context::apply_server_biases(server_slot& slot) {
    const bool in_range = slot.allow_idx < slot.allow_biasess.size();
    slot.ctx_sampling->server_biases = in_range ? &slot.allow_biasess[slot.allow_idx] : nullptr;
}
// Create a completion task from the request payload and queue it.
// A multi-element "prompt" array of strings is fanned out into one subtask
// per prompt; an array containing numbers is treated as a single pre-tokenized
// prompt and queued as-is.
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
    server_task task;
    task.id = id_task;
    task.id_multi = id_multi;
    task.id_target = 0;
    task.data = std::move(data);
    task.infill = infill;
    task.embedding = embedding;
    task.type = SERVER_TASK_TYPE_COMPLETION;
    task.tokens = std::move(inputs);
    // when a completion task's prompt array is not a singleton, we split it into multiple requests
    // otherwise, it's a single-prompt task, we actually queue it
    // if there's numbers in the prompt array it will be treated as an array of tokens
    if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
        // scan for any numeric element: its presence marks a token array
        bool numbers = false;
        for (const auto& e : task.data.at("prompt")) {
            if (e.is_number()) {
                numbers = true;
                break;
            }
        }
        // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
        // it will completely stall the server. I don't know where the bug for this is.
        //
        // if there are numbers, it needs to be treated like a single prompt,
        // queue_tasks handles a mix of strings and numbers just fine.
        if (numbers) {
            queue_tasks.post(std::move(task));
        }
        else {
            split_multiprompt_task(id_task, task);
        }
    }
    else {
        queue_tasks.post(std::move(task));
    }
}
void server_context::request_cancel(int id_task) {
server_task task;
task.type = SERVER_TASK_TYPE_CANCEL;
task.id_target = id_task;
queue_tasks.post(std::move(task));
}
// Fan a multi-prompt task out into one subtask per prompt, registering the
// parent multitask so subtask completion can be tracked and aggregated.
void server_context::split_multiprompt_task(int id_multi, server_task& multiprompt_task) {
    const int prompt_count = multiprompt_task.data.at("prompt").size();
    if (prompt_count <= 1) {
        send_error(multiprompt_task, "error while handling multiple prompts");
        return;
    }
    // generate all the ID for subtask
    std::vector<int> subtask_ids(prompt_count);
    for (int i = 0; i < prompt_count; i++) {
        subtask_ids[i] = queue_tasks.get_new_id();
    }
    // queue up the multitask so we can track its subtask progression
    queue_tasks.add_multitask(id_multi, subtask_ids);
    // add subtasks
    for (int i = 0; i < prompt_count; i++) {
        // each subtask gets its own single prompt from the array
        json subtask_data = multiprompt_task.data;
        subtask_data["prompt"] = subtask_data.at("prompt")[i];
        // subtasks inherit everything else (infill mode, embedding mode, etc.)
        // NOTE(review): tokens is std::move'd on every iteration — only the
        // first subtask receives it, the rest get a moved-from (empty) object.
        // Presumably harmless because each subtask carries its prompt in
        // subtask_data, but worth confirming.
        request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding,
            std::move(multiprompt_task.tokens));
    }
}
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
if (checkpoints.size() == 0) {
return 0;
}
std::ofstream file(filename, std::ios::binary);
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
uint32_t version = LLAMA_STATE_SEQ_VERSION;
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
size_t count = checkpoints.size();
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
for (const auto & checkpoint : checkpoints) {
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
size_t data_len = checkpoint.data.size();
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
if (data_len > 0) {
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
}
}
size_t pos = file.tellp();
file.close();
return pos;
}
// Load prompt checkpoints previously written by save_checkpoints_to_file().
// Validates the magic/version header and the stream state after every
// checkpoint so a truncated or corrupt file cannot leave garbage behind.
// Returns the number of bytes read, or 0 on any failure (checkpoints is
// left empty in that case).
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        return 0;
    }
    checkpoints.clear();
    // version checks
    {
        uint32_t magic = 0;
        uint32_t version = 0;
        file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
        file.read(reinterpret_cast<char *>(&version), sizeof(version));
        if (!file.good() || magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
            return 0;
        }
    }
    // load the checkpoints
    {
        size_t count = 0;
        file.read(reinterpret_cast<char *>(&count), sizeof(count));
        // size_t loop variable: the original compared a signed int against
        // size_t, which is both a warning and wrong for huge counts
        for (size_t i = 0; i < count; i++) {
            server_prompt_checkpoint checkpoint;
            file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
            file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
            file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
            file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
            size_t data_len = 0;
            file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
            if (!file.good()) {
                // truncated/corrupt file: discard any partially-loaded state
                LLAMA_LOG_ERROR("%s: truncated checkpoint file: %s\n", __func__, filename.c_str());
                checkpoints.clear();
                return 0;
            }
            if (data_len > 0) {
                checkpoint.data.resize(data_len);
                file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
                if (!file.good()) {
                    LLAMA_LOG_ERROR("%s: truncated checkpoint file: %s\n", __func__, filename.c_str());
                    checkpoints.clear();
                    return 0;
                }
            }
            // move to avoid copying the (potentially large) data blob
            checkpoints.push_back(std::move(checkpoint));
        }
    }
    size_t pos = file.tellg();
    file.close();
    return pos;
}
// Persist the token list as JSON, tagged with magic/version fields so the
// loader can validate it. Returns bytes written (0 when the file could not
// be opened).
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
    std::ofstream file(filename, std::ios::binary);
    json token_json = tokens.to_json();
    token_json["magic"]   = LLAMA_SERVER_MAGIC;
    token_json["version"] = LLAMA_SERVER_VERSION;
    if (!file.is_open()) {
        return 0;
    }
    file << token_json;
    const size_t pos = file.tellp();
    file.close();
    return pos;
}
// Load a token list previously written by save_server_tokens_to_file().
// Validates the magic/version fields before applying the data.
// Returns the number of bytes read, or 0 on any failure (missing file,
// malformed JSON, wrong magic/version).
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        return 0;
    }
    size_t pos = 0;
    json token_json;
    // the original let a malformed file propagate a json parse exception,
    // which would take down the server on a corrupt restore file
    try {
        file >> token_json;
    } catch (const std::exception & e) {
        LLAMA_LOG_ERROR("%s: failed to parse token file '%s': %s\n", __func__, filename.c_str(), e.what());
        return 0;
    }
    // clear eofbit (set when parsing consumed the whole file) so that
    // tellg() reports the read offset instead of -1
    if (file.eof()) {
        file.clear();
    }
    pos = file.tellg();
    file.close();
    uint32_t magic = token_json.value<uint32_t>("magic", 0);
    uint32_t version = token_json.value<uint32_t>("version", 0);
    if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
        LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
        return 0;
    }
    tokens.from_json(token_json);
    return pos;
}
// Central task dispatcher, run on the task-queue thread: routes one task to
// slot assignment (completion/infill/embedding/rerank), cancellation,
// metrics reporting, slot state save/restore/erase, LoRA and control-vector
// management. Tasks that cannot run now are re-deferred to the queue.
void server_context::process_single_task(server_task&& task) {
    switch (task.type) {
    case SERVER_TASK_TYPE_COMPLETION:
    case SERVER_TASK_TYPE_INFILL:
    case SERVER_TASK_TYPE_EMBEDDING:
    case SERVER_TASK_TYPE_RERANK:
    {
        // either the client pinned a slot via "id_slot" or we pick one
        const int id_slot = json_value(task.data, "id_slot", -1);
        server_slot* slot;
        if (id_slot != -1) {
            slot = get_slot_by_id(id_slot);
        }
        else {
            slot = get_available_slot(task);
        }
        if (slot == nullptr) {
            // if no slot is available, we defer this task for processing later
            LOG_VERBOSE("no slot is available", { {"id_task", task.id} });
            queue_tasks.defer(std::move(task));
            break;
        }
        if (!slot->available()) {
            // if requested slot is unavailable, we defer this task for processing later
            LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
            queue_tasks.defer(std::move(task));
            break;
        }
        if (task.data.contains("system_prompt")) {
            // a new system prompt invalidates every slot's processed context
            std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
            system_prompt_set(sys_prompt);
            for (server_slot& slot : slots) {
                slot.n_past = 0;
                slot.n_past_se = 0;
            }
        }
        slot->reset();
        slot->id_task = task.id;
        slot->id_multi = task.id_multi;
        slot->infill = task.infill;
        slot->embedding = task.embedding;
        if (!launch_slot_with_task(*slot, task)) {
            LOG_ERROR("error while launching slot", task.data);
            break;
        }
    } break;
    case SERVER_TASK_TYPE_CANCEL:
    {
        // release slot linked with the task id
        for (auto& slot : slots) {
            if (slot.id_task == task.id_target) {
                slot.release();
                break;
            }
        }
    } break;
    case SERVER_TASK_TYPE_NEXT_RESPONSE:
    {
        // do nothing
    } break;
    case SERVER_TASK_TYPE_METRICS:
    {
        // snapshot per-slot state plus aggregate metrics into one result
        json slots_data = json::array();
        int n_idle_slots = 0;
        int n_processing_slots = 0;
        for (server_slot& slot : slots) {
            json slot_data = get_formated_generation(slot);
            slot_data["id"] = slot.id;
            slot_data["id_task"] = slot.id_task;
            slot_data["state"] = slot.state;
            slot_data["prompt"] = slot.prompt;
            slot_data["next_token"] = {
                {"has_next_token", slot.has_next_token},
                {"n_remain", slot.n_remaining},
                {"n_decoded", slot.n_decoded},
                {"stopped_eos", slot.stopped_eos},
                {"stopped_word", slot.stopped_word},
                {"stopped_limit", slot.stopped_limit},
                {"stopping_word", slot.stopping_word},
            };
            if (slot_data["state"] == SLOT_STATE_IDLE) {
                n_idle_slots++;
            }
            else {
                n_processing_slots++;
            }
            slots_data.push_back(slot_data);
        }
        LOG_INFO("slot data", {
            {"id_task", task.id},
            {"n_idle_slots", n_idle_slots},
            {"n_processing_slots", n_processing_slots}
        });
        LOG_VERBOSE("slot data", {
            {"id_task", task.id},
            {"n_idle_slots", n_idle_slots},
            {"n_processing_slots", n_processing_slots},
            {"slots", slots_data}
        });
        server_task_result res;
        res.id = task.id;
        res.id_multi = task.id_multi;
        res.stop = true;
        res.error = false;
        res.data = {
            { "idle", n_idle_slots },
            { "processing", n_processing_slots },
            { "deferred", queue_tasks.queue_tasks_deferred.size() },
            { "t_start", metrics.t_start},
            { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
            { "t_tokens_generation_total", metrics.t_tokens_generation_total},
            { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
            { "t_prompt_processing_total", metrics.t_prompt_processing_total},
            { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
            { "t_prompt_processing", metrics.t_prompt_processing},
            { "n_tokens_predicted", metrics.n_tokens_predicted},
            { "t_tokens_generation", metrics.t_tokens_generation},
            { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
            { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
            { "slots", slots_data },
        };
        // the Prometheus-style bucket counters can be reset on demand
        if (json_value(task.data, "reset_bucket", false)) {
            metrics.reset_bucket();
        }
        queue_results.send(res);
    } break;
    case SERVER_TASK_TYPE_SLOT_SAVE:
    {
        // persist a slot's KV-cache state, token cache and checkpoints to disk
        int id_slot = task.data.at("id_slot");
        server_slot* slot = get_slot_by_id(id_slot);
        if (slot == nullptr) {
            send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
            break;
        }
        if (!slot->available()) {
            // if requested slot is unavailable, we defer this task for processing later
            LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
            queue_tasks.defer(std::move(task));
            break;
        }
        const size_t token_count = slot->cache_tokens.size();
        const int64_t t_start = ggml_time_us();
        std::string filename = task.data.at("filename");
        std::string filepath = task.data.at("filepath");
        // three files per save: <name>.tokens.json, <name>.checkpoints and
        // the llama sequence state in <name> itself
        save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
        size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
        const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
        const int64_t t_end = ggml_time_us();
        const double t_save_ms = (t_end - t_start) / 1000.0;
        server_task_result result;
        result.id = task.id;
        result.stop = true;
        result.error = false;
        result.data = json{
            { "id_slot", id_slot },
            { "filename", filename },
            { "n_saved", token_count }, // tokens saved
            { "n_written", nwrite + saved }, // bytes written
            { "timings", {
                { "save_ms", t_save_ms }
            } }
        };
        queue_results.send(result);
    } break;
    case SERVER_TASK_TYPE_SLOT_RESTORE:
    {
        // restore a slot's KV-cache state, token cache and checkpoints from disk
        int id_slot = task.data.at("id_slot");
        server_slot* slot = get_slot_by_id(id_slot);
        if (slot == nullptr) {
            send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
            break;
        }
        if (!slot->available()) {
            // if requested slot is unavailable, we defer this task for processing later
            LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
            queue_tasks.defer(std::move(task));
            break;
        }
        const int64_t t_start = ggml_time_us();
        std::string filename = task.data.at("filename");
        std::string filepath = task.data.at("filepath");
        // provide a full-context-sized buffer for the loader to fill
        slot->cache_tokens.resize(slot->n_ctx);
        size_t token_count = 0;
        size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
        if (nread == 0) {
            slot->cache_tokens.resize(0);
            send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
            break;
        }
        load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
        size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
        const int64_t t_end = ggml_time_us();
        const double t_restore_ms = (t_end - t_start) / 1000.0;
        server_task_result result;
        result.id = task.id;
        result.stop = true;
        result.error = false;
        result.data = json{
            { "id_slot", id_slot },
            { "filename", filename },
            { "n_restored", token_count }, // tokens restored
            { "n_read", nread }, // bytes read
            { "timings", {
                { "restore_ms", t_restore_ms }
            } }
        };
        queue_results.send(result);
    } break;
    case SERVER_TASK_TYPE_SLOT_ERASE:
    {
        int id_slot = task.data.at("id_slot");
        server_slot* slot = get_slot_by_id(id_slot);
        if (slot == nullptr) {
            send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
            break;
        }
        if (!slot->available()) {
            // if requested slot is unavailable, we defer this task for processing later
            LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
            queue_tasks.defer(std::move(task));
            break;
        }
        // Erase token cache
        const size_t n_erased = slot->cache_tokens.size();
        llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
        slot->cache_tokens.keep_first(0);
        //slot->cache_tokens.clear();
        slot->server_cached_prompt.checkpoints.clear();
        slot->server_cached_prompt.data.clear();
        server_task_result result;
        result.id = task.id;
        result.stop = true;
        result.error = false;
        result.data = json{
            { "id_slot", id_slot },
            { "n_erased", n_erased }
        };
        queue_results.send(result);
    } break;
    case SERVER_TASK_TYPE_SET_LORA:
    {
        // re-apply the currently configured LoRA adapters to the context
        llama_lora_adapters_apply(ctx, lora_adapters);
        server_task_result result;
        result.id = task.id;
        result.stop = true;
        result.error = false;
        result.data = json{ { "success", true } };
        queue_results.send(result);
    } break;
    case SERVER_TASK_TYPE_LOAD_CONTROL_VECTOR:
    {
        // Load control vector from file
        std::string path = task.data.at("path");
        float scale = task.data.value("scale", 1.0f);
        int32_t layer_start = task.data.value("layer_start", 1);
        int32_t layer_end = task.data.value("layer_end", llama_n_layer(model));
        // Check if already loaded
        int cv_id = -1;
        for (size_t i = 0; i < control_vectors.size(); i++) {
            if (control_vectors[i].path == path) {
                // same file: just update its scale/layer range
                control_vectors[i].scale = scale;
                control_vectors[i].layer_start = layer_start;
                control_vectors[i].layer_end = layer_end;
                cv_id = i;
                break;
            }
        }
        if (cv_id == -1) {
            control_vector_container new_cv;
            new_cv.path = path;
            new_cv.scale = scale;
            new_cv.layer_start = layer_start;
            new_cv.layer_end = layer_end;
            new_cv.applied = false;
            // Load the control vector data
            llama_control_vector_load_info load_info;
            load_info.fname = path;
            load_info.strength = 1.0f; // Don't pre-scale here, we'll scale when applying
            std::vector<llama_control_vector_load_info> load_infos = { load_info };
            new_cv.data = llama_control_vector_load(load_infos);
            if (new_cv.data.n_embd == -1) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false }, { "error", "Failed to load control vector from " + path }};
                queue_results.send(result);
                break;
            }
            // Validate dimension to prevent heap corruption
            if (new_cv.data.n_embd != llama_model_n_embd(model)) {
                server_task_result result;
                result.id = task.id;
                result.error = true;
                result.data = json{{ "success", false },
                    { "error", "Vector dimension mismatch" }};
                queue_results.send(result);
                break;
            }
            control_vectors.push_back(new_cv);
            cv_id = control_vectors.size() - 1;
        }
        // Auto-apply control vectors after loading
        if (!apply_control_vectors_internal()) {
            server_task_result result;
            result.id = task.id;
            result.error = true;
            result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
            queue_results.send(result);
            break;
        }
        server_task_result result;
        result.id = task.id;
        result.error = false;
        result.data = json{{ "success", true }, { "id", cv_id }};
        queue_results.send(result);
    } break;
    case SERVER_TASK_TYPE_UNLOAD_CONTROL_VECTOR:
    {
        // Validate that "id" field exists and is a number
        if (!task.data.contains("id") || task.data["id"].is_null() || !task.data["id"].is_number()) {
            server_task_result result;
            result.id = task.id;
            result.error = true;
            result.data = json{{ "success", false }, { "error", "Missing or invalid 'id' field" }};
            queue_results.send(result);
            break;
        }
        int id = task.data.at("id");
        if (id < 0 || id >= (int)control_vectors.size()) {
            server_task_result result;
            result.id = task.id;
            result.error = true;
            result.data = json{{ "success", false }, { "error", "Invalid control vector ID" }};
            queue_results.send(result);
            break;
        }
        // Remove the control vector from the list
        control_vectors.erase(control_vectors.begin() + id);
        // Reapply remaining control vectors
        if (!apply_control_vectors_internal()) {
            server_task_result result;
            result.id = task.id;
            result.error = true;
            result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
            queue_results.send(result);
            break;
        }
        server_task_result result;
        result.id = task.id;
        result.error = false;
        result.data = json{{ "success", true }};
        queue_results.send(result);
    } break;
    case SERVER_TASK_TYPE_SET_CONTROL_VECTOR:
    {
        // re-apply whatever control vectors are currently configured
        if (!apply_control_vectors_internal()) {
            server_task_result result;
            result.id = task.id;
            result.error = true;
            result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
            queue_results.send(result);
            break;
        }
        server_task_result result;
        result.id = task.id;
        result.error = false;
        result.data = json{{ "success", true }};
        queue_results.send(result);
    } break;
    }
}
// Combine all active (non-zero-scale) control vectors into one scaled sum
// and apply it to the context over the union of their layer ranges.
// Returns false only when llama_control_vector_apply() fails.
bool server_context::apply_control_vectors_internal() {
    llama_control_vector_data combined_cv = { -1, {} };
    // Check if we have anything to apply
    bool any_active = false;
    for (const auto& cv : control_vectors) {
        if (cv.scale != 0.0f) {
            any_active = true;
            break;
        }
    }
    if (!any_active) {
        // Clear control vectors if nothing is active
        llama_control_vector_apply(ctx, nullptr, 0, 0, 0, 0);
        return true;
    }
    // Aggregate control vectors with scaling
    for (auto& cv : control_vectors) {
        if (cv.scale == 0.0f) {
            cv.applied = false;
            continue;
        }
        if (combined_cv.n_embd == -1) {
            combined_cv.n_embd = cv.data.n_embd;
        }
        // vectors covering different layer ranges have different data sizes;
        // grow the combined buffer as needed (the original sized it once to
        // the first active vector and then indexed past its end for larger ones)
        if (combined_cv.data.size() < cv.data.data.size()) {
            combined_cv.data.resize(cv.data.data.size(), 0.0f);
        }
        for (size_t i = 0; i < cv.data.data.size(); i++) {
            combined_cv.data[i] += cv.data.data[i] * cv.scale;
        }
        cv.applied = true;
    }
    // Apply combined control vector
    if (combined_cv.n_embd != -1 && !combined_cv.data.empty()) {
        // apply across the union of all active layer ranges
        int32_t min_layer_start = INT32_MAX;
        int32_t max_layer_end = 0;
        for (const auto& cv : control_vectors) {
            if (cv.scale != 0.0f) {
                min_layer_start = std::min(min_layer_start, cv.layer_start);
                max_layer_end = std::max(max_layer_end, cv.layer_end);
            }
        }
        int err = llama_control_vector_apply(ctx,
            combined_cv.data.data(),
            combined_cv.data.size(),
            combined_cv.n_embd,
            min_layer_start,
            max_layer_end);
        return (err == 0);
    }
    return true;
}
// Aggregate the results of all subtasks of a finished multitask into a
// single result and enqueue it.
void server_context::on_finish_multitask(const server_task_multi& multitask) {
    // all subtasks done == multitask is done
    server_task_result result;
    result.id = multitask.id;
    result.stop = true;
    result.error = false;
    // collect json results into one json result
    std::vector<json> result_jsons;
    for (const auto& subres : multitask.results) {
        result_jsons.push_back(subres.data);
        // flag the multitask as errored if ANY subtask errored
        // (the original used `error && subres.error`, which starting from
        // false could never become true)
        result.error = result.error || subres.error;
    }
    result.data = json{
        { "results", result_jsons }
    };
    queue_results.send(result);
}
// Debug helper: log a detokenized window of the cache and prompt token
// lists, starting at start2 / start1 respectively, when the window start
// lies within the list.
void server_context::print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1, size_t start2, size_t length) {
    const bool cache_in_range  = cache.size()  > start2;
    const bool prompt_in_range = prompt.size() > start1;
    if (cache_in_range) {
        LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str());
    }
    if (prompt_in_range) {
        LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str());
    }
}
// Drop n_discard cached tokens after the first n_keep, both from the KV
// cache (removing the range and shifting the remainder back) and, when
// prompt caching is on, from the slot's token cache. Token indices are
// translated to KV positions via pos_next().
void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
    auto kv_keep    = slot.cache_tokens.pos_next(n_keep);
    auto kv_discard = slot.cache_tokens.pos_next(n_keep + n_discard) - kv_keep;
    auto kv_past    = slot.cache_tokens.pos_next(slot.n_past);
    // removed: unused pos_min/pos_max queries (dead locals in the original)
    llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
    // shift the surviving tail left so positions stay contiguous
    llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
    if (slot.has_mtp && slot.spec) {
        // keep the speculative (MTP) context in sync with the shift
        common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
    }
    if (slot.params.cache_prompt) {
        slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
    }
}
inline static bool tokens_support_context_shift(const server_tokens & tokens, int32_t n_keep,
int32_t n_discard) {
bool can_shift = !tokens.has_mtmd;
if (tokens.has_mtmd) {
can_shift = true;
if (n_keep > 0 && n_keep<= tokens.n_tokens()) {
can_shift = tokens[n_keep - 1] != LLAMA_TOKEN_NULL;
}
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
can_shift = can_shift && tokens[n_discard + n_keep - 1] != LLAMA_TOKEN_NULL;
}
}
return can_shift;
}
// Nudge the (n_keep, n_discard) shift boundaries off multimodal chunks:
// n_keep is walked down and n_discard walked up until neither boundary sits
// on a LLAMA_TOKEN_NULL placeholder (or the walk leaves the valid range).
// No-op for text-only token lists.
inline static void adjust_n_to_support_context_shift(const server_tokens & tokens, int32_t & n_keep,
    int32_t & n_discard) {
    if (!tokens.has_mtmd) {
        return;
    }
    // shrink n_keep until its boundary token is not a multimodal placeholder;
    // the bounds check after the decrement guards the next iteration's access
    if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
        while (tokens[n_keep - 1] == LLAMA_TOKEN_NULL) {
            n_keep--;
            if (n_keep<1 || n_keep>tokens.size()) {
                break;
            }
        }
    }
    // grow n_discard until the discard-end boundary clears the placeholder run
    if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
        while (tokens[n_discard + n_keep - 1] == LLAMA_TOKEN_NULL) {
            n_discard++;
            if (n_discard + n_keep<1 || n_discard + n_keep>tokens.size()) {
                break;
            }
        }
    }
}
// convert keep first few and discard next tokens in a to b
void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) {
common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact);
common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact);
// only if there is enough common token
int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep);
int32_t keep_offset = ctx_keep_prefix.first - n_keep;
n_kept = ctx_keep_prefix.second - keep_offset;
n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset;
if (n_kept < 0) {
n_kept = n_keep;
}
if (n_discarded < 0) {
n_discarded = n_discard;
}
}
// Truncate an over-long prompt to fit the slot's context window: keep the
// first n_keep tokens, repeatedly discard n_discard-sized chunks until the
// remainder fits, and discard the matching region from the token/KV cache
// so the surviving suffix does not need reprocessing.
void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
    // account for the BOS token when computing how much to keep
    int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
    const int n_left = slot.n_ctx - n_keep;
    int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
    adjust_n_to_support_context_shift(slot.prompt_tokens, n_keep, n_discard);
    if (n_discard<=0 || !tokens_support_context_shift(slot.prompt_tokens, n_keep, n_discard)) {
        return;
    }
    int n_discard_prompt = 0;
    // we still need to truncate input since we have not discarded enough tokens
    while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) {
        slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
        n_discard_prompt = n_discard_prompt + n_discard;
    }
    // Handle mistokenization between prompt and cache during context shift
    //
    int32_t n_discard_cache = n_discard_prompt;
    int32_t n_kept = n_keep;
    // first drop the tokens already discarded in previous shifts
    slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt);
    if (n_discard_prompt > 0) {
        // translate the prompt-side boundaries into cache-side counts
        context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
            n_discard, n_kept, n_discard_cache, exact);
    }
    // never discard more cache tokens than exist past the kept prefix
    int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0);
    n_discard_cache = std::min(n_discard_cache, n_discard_cache_max);
    // discard matching tokens from cache and kv cache to avoid reprocessing the prompt
    if (n_discard_cache > 0) {
        discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
    }
    // discard extra tokens from prompts
    slot.n_kept_prompt = n_keep;
    slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt);
    slot.n_prompt_tokens = slot.prompt_tokens.size();
}
// Transition every slot flagged with SLOT_COMMAND_RELEASE back to the idle
// state, stamp its last-used time, and notify the task queue so deferred
// work can be rescheduled.
void server_context::release_slots()
{
    for (auto& slot : slots) {
        if (slot.command != SLOT_COMMAND_RELEASE) {
            continue;
        }
        slot.state       = SLOT_STATE_IDLE;
        slot.command     = SLOT_COMMAND_NONE;
        slot.t_last_used = ggml_time_us();
        LOG_INFO("slot released", {
            {"id_slot", slot.id},
            {"id_task", slot.id_task},
            {"n_ctx", n_ctx},
            {"n_past", slot.n_past},
            {"n_system_tokens", system_tokens.size()},
            {"n_cache_tokens", slot.cache_tokens.size()},
            {"truncated", slot.truncated}
        });
        queue_tasks.notify_slot_changed();
    }
}
// Report whether every slot is idle with no pending command. When all are
// idle (and no system prompt is set), optionally clears the KV cache.
bool server_context::slots_idle(){
    for (const auto& slot : slots) {
        const bool busy = slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE;
        if (busy) {
            return false;
        }
    }
    LOG_INFO("all slots are idle", {});
    if (system_prompt.empty() && clean_kv_cache) {
        kv_cache_clear();
    }
    return true;
}
// Per-iteration context shift during generation: for each processing slot
// whose position is about to exceed its context window, discard a chunk of
// old tokens (after the kept prefix) from the KV/token caches so generation
// can continue. Only active when self-extend (ga_n) is disabled.
void server_context::context_shift() {
    for (server_slot& slot : slots) {
        // ga_n != 1 means grouped self-attention handles long context instead
        if (slot.ga_n == 1) {
            if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
                if (!params_base.ctx_shift) {
                    // this check is redundant (for good)
                    // we should never get here, because generation should already stopped in process_token()
                    slot.print_timings();
                    slot.release();
                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
                    continue;
                }
                // Shift context
                // n_keep < 0 means keep the whole prompt
                int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
                if (add_bos_token) {
                    n_keep += 1;
                }
                // leave some headroom below the context size
                n_keep = std::min(slot.n_ctx - 4, n_keep);
                const int32_t n_left = (int)system_tokens.size() + slot.n_past - n_keep;
                int32_t n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
                int32_t n_kept;
                int32_t n_discard_cache;
                // move the boundaries off multimodal chunks if necessary
                adjust_n_to_support_context_shift(slot.cache_tokens, n_keep, n_discard);
                if (n_discard > 0 && tokens_support_context_shift(slot.cache_tokens, n_keep, n_discard)) {
                    // translate prompt-side boundaries to cache-side counts
                    context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
                        n_discard, n_kept, n_discard_cache);
                    LOG_INFO("slot context shift", {
                        {"id_slot", slot.id},
                        {"id_task", slot.id_task},
                        {"n_keep", n_keep},
                        {"n_left", n_left},
                        {"n_discard", n_discard},
                        {"n_ctx", n_ctx},
                        {"n_past", slot.n_past},
                        {"n_system_tokens", system_tokens.size()},
                        {"n_cache_tokens", slot.cache_tokens.size()}
                    });
                    slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
                    slot.n_kept_prompt = n_keep;
                    discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
                    slot.n_past -= n_discard_cache;
                    slot.truncated = true;
                }
            }
        }
    }
}
// Queue this iteration's sampled token into the shared batch for every active
// slot. When speculative decoding is enabled for a slot, also generate draft
// tokens and append them to the batch so they can be verified in one decode.
void server_context::add_sampled_tokens() {
for (auto& slot : slots) {
slot.released = false;
if (slot.state == SLOT_STATE_IDLE) {
continue;
}
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max_pre = slot.get_n_draft_max();
if (n_draft_max_pre > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
auto & params_spec = slot.params.speculative;
if (slot.has_mtp) {
// MTP drafting consumes the target model's hidden state; prefer the
// dedicated MTP context when the speculator provides one
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
if (!slot.mtp_hidden_state.empty()) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
// feed only the most recent hidden-state row to the draft input
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
} else {
LOG_ERROR("MTP hidden state is empty during speculation", {});
// best-effort recovery: rebuild the hidden state from the last
// embeddings row of the target context, if available
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
if (emb_neg1) {
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
}
}
}
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
// re-query the limit: autotune may shrink it between calls
const int n_draft_max = slot.get_n_draft_max();
if (draft.size() > (size_t)n_draft_max) {
if (slot.params.speculative.autotune) {
// expected near end-of-response when autotune shrinks n_max
SLT_DBG(slot, "draft size %d exceeds max %d, truncating\n", (int)draft.size(), n_draft_max);
} else {
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int)draft.size(), n_draft_max);
}
draft.resize(n_draft_max);
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);
if (slot.params.speculative.n_min > (int)draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
}
else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
}
else {
// no speculative decoding
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
slot.cache_tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
(int)slot.n_ctx, (int)slot.cache_tokens.size(), (int)slot.truncated);
}
// keep n_past in sync with the token cache after appending
slot.n_past = slot.cache_tokens.n_tokens();
}
}
// Create a periodic context checkpoint once the sequence has advanced at
// least `ctx_checkpoints_interval` positions past the last recorded one.
// No-op when checkpointing is disabled or the interval is non-positive.
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
    if (!params_base.do_checkpoint || params_base.ctx_checkpoints_interval <= 0) {
        return;
    }
    const auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
    // not enough new positions since the previous checkpoint yet
    if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval > 1 + pos) {
        return;
    }
    // only advance the marker if the checkpoint was actually created
    if (create_checkpoint(slot)) {
        slot.checkpoint_pos = pos;
    }
}
// Reconcile the slot's reusable prefix with what the KV cache can actually
// provide (SWA/recurrent memory may have dropped early positions). Restores
// the most recent usable context checkpoint, or forces a full re-process if
// none fits; finally prunes checkpoints that are now past the reuse point.
void server_context::apply_checkpoint(server_slot & slot) {
llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
const auto pos_min_thold = std::max(0, pos_next - 1);
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
// pos_min > threshold means the KV cache no longer holds the positions we
// intended to reuse - try to fall back to a saved checkpoint
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
// search for a context checkpoint (newest first)
const auto it = std::find_if(
slot.server_cached_prompt.checkpoints.rbegin(),
slot.server_cached_prompt.checkpoints.rend(),
[&](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
return cur.pos_min < pos_min_thold;
}
);
bool do_reset = it == slot.server_cached_prompt.checkpoints.rend();
if (!do_reset) {
// restore the context checkpoint
const int64_t t_start = ggml_time_us();
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
// clamp the reusable lengths to what the checkpoint covers, both in
// cache-token space and in prompt-token space (they can differ by
// n_past_offset when mistokenization was handled)
slot.n_past = std::min(slot.n_past, std::max(it->pos_min+1, it->pos_max));
slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt+1, it->pos_max_prompt));
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0;
slot.n_past_prompt = 0;
}
}
}
{
// erase any checkpoints with pos_min > pos_min_thold
for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
it = slot.server_cached_prompt.checkpoints.erase(it);
} else {
++it;
}
}
}
}
// Snapshot the slot's partial sequence state (LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY)
// so it can be restored later by apply_checkpoint(). Returns true when a
// checkpoint was actually created.
bool server_context::create_checkpoint(server_slot & slot) {
    const int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
    const auto    pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);

    // never checkpoint right after an image chunk was processed
    if (slot.image_just_processed) {
        return false;
    }
    // no need for empty or small checkpoints
    if (pos_min < 0 || pos_max < 16) {
        return false;
    }
    auto & checkpoints = slot.server_cached_prompt.checkpoints;
    // no need to create checkpoints that are too close together
    if (!checkpoints.empty() && pos_max <= checkpoints.back().pos_max) {
        return false;
    }

    const int64_t t_start = ggml_time_us();

    // evict the oldest checkpoints until there is room for the new one
    while (checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
        const auto & old = checkpoints.front();
        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
            old.pos_min, old.pos_max, (float)old.data.size() / 1024 / 1024);
        checkpoints.erase(checkpoints.begin());
    }

    const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    // prompt-space positions differ from cache-space by n_past_offset
    auto & cur = checkpoints.emplace_back(server_prompt_checkpoint{
        /*.pos_min = */ pos_min,
        /*.pos_max = */ pos_max,
        /*.pos_min_prompt = */ pos_min + slot.n_past_offset,
        /*.pos_max_prompt = */ pos_max + slot.n_past_offset,
        /*.data = */ std::vector<uint8_t>(checkpoint_size),
    });

    llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
        (int)checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
        (ggml_time_us() - t_start) / 1000.0);

    return true;
}
// Fill the shared batch with prompt tokens from slots that are waiting to be
// processed: tokenize (incl. infill assembly), handle context shift and prompt
// cache reuse, restore checkpoints, process multimodal chunks, and mark the
// final prompt token for logits extraction. batch_type is -1 until the first
// slot fixes it (0 = completion, 1 = embedding); mixed types are deferred.
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto& slot : slots) {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
auto& prompt_tokens = slot.prompt_tokens;
// we haven't tokenized the prompt yet - do it now:
if (prompt_tokens.empty() || slot.n_prompt_tokens == 0) {
LOG_VERBOSE("tokenizing prompt", {
{"id_slot", slot.id},
{"id_task", slot.id_task}
});
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;
if (slot.infill) {
// assemble prefix/suffix/middle layout for fill-in-the-middle
const bool add_bos = llama_should_add_bos_token(model);
bool suff_rm_leading_spc = true;
// NOTE(review): this mutates the shared params_base.input_suffix in
// place, yet the tokens below come from slot.params.input_suffix -
// verify this prefix-space handling is intentional across requests
if (params_base.input_suffix.find_first_of(' ') == 0 && params_base.input_suffix.size() > 1) {
params_base.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
// spm_infill swaps the order: suffix first, then prefix
auto embd_inp = params_base.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params_base.spm_infill ? prefix_tokens : suffix_tokens;
if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}
prompt_tokens = server_tokens(embd_inp, false);
}
else {
// prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
}
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("prompt tokenized", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
});
// empty prompt passed -> release the slot and send empty response
if (prompt_tokens.empty()) {
LOG_INFO("empty prompt - releasing slot", {
{"id_slot", slot.id},
{"id_task", slot.id_task}
});
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
send_final_response(slot);
slot.release();
slot.print_timings();
continue;
}
if (slot.embedding) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue;
}
}
else {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
// context shift for prompt processing
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
if (!params_base.ctx_shift) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
slot.release();
continue;
}
context_shift_prompt(ctx, slot);
slot.truncated = true;
LOG_VERBOSE("input truncated", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", slot.n_ctx - slot.params.n_keep},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
});
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
#ifndef NDEBUG
// debug
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
int32_t back = 1;
if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first + 20
&& prefix.second >= back && prefix.first >= back) {
LLAMA_LOG_INFO("After context shift :\n");
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50);
}
#endif
}
else {
slot.n_discarded_prompt = 0;
}
common_sampler_reset(slot.ctx_sampling);
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
slot.ga_i = 0;
}
else {
GGML_ASSERT(slot.ga_n == 1);
// reuse any previously computed tokens that are common with the new prompt
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match
common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match
LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t)slot.cache_tokens.size(), (int32_t)n_past0, (int32_t)prefix.first, (int32_t)prefix.second, (int32_t)prefix_nonexact.first, (int32_t)prefix_nonexact.second);
// tolerate small whitespace/newline mismatches between cache and prompt
int32_t size_threshold = 20;
if (prefix.first + size_threshold < prefix_nonexact.first) {
LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n");
prefix = prefix_nonexact;
}
slot.n_past = prefix.first;
slot.n_past_prompt = prefix.second;
// offset between prompt-token index and cache-token index
slot.n_past_offset = slot.n_past_prompt - slot.n_past;
if (slot.n_past != slot.n_past_prompt) {
LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
}
if ((slot.n_past + size_threshold < slot.cache_tokens.size()))
{
LLAMA_LOG_WARN("Common part does not match fully\n");
int32_t back = 4;
if (prefix.second >= back && prefix.first >= back) {
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30);
}
}
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
common_sampler_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
}
}
// may shrink n_past/n_past_prompt if the KV cache lost positions
apply_checkpoint(slot);
if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) {
// we have to evaluate at least 1 token to generate logits.
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task }
});
slot.n_past_prompt--;
slot.n_past--;
if (slot.ga_i > 0) {
slot.n_past_se--;
}
}
slot.n_prompt_tokens_processed = 0;
}
if (slot.embedding) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
}
}
// check that we are in the right batch_type, if not defer the slot
bool slot_type = slot.embedding ? 1 : 0;
if (batch_type == -1) {
batch_type = slot_type;
}
else if (batch_type != slot_type) {
continue;
}
// keep only the common part
// remove the non-common part from the cache
if (slot.n_past < 0)
{
slot.n_past = 0;
}
slot.cache_tokens.keep_first(slot.n_past);
int p0 = (int)system_tokens.size() + slot.n_past;
p0 = system_tokens.size() + slot.cache_tokens.pos_next();
if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
p0 = (int)system_tokens.size();
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1);
}
// there is no common part left (except for the system prompt)
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
common_sampler_reset(slot.ctx_sampling);
}
LOG_INFO("kv cache rm [p0, end)", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "p0", p0 }
});
// check if we should process the image
if (slot.n_past_prompt < slot.n_prompt_tokens
&& slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) {
// process the image
size_t n_tokens_out = 0;
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out);
if (res != 0) {
LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
slot.release();
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
continue;
}
// add the image chunk to cache
{
const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt);
slot.cache_tokens.push_back(chunk.get()); // copy
}
slot.n_past += n_tokens_out;
slot.n_past_prompt += n_tokens_out;
slot.n_prompt_tokens_processed += n_tokens_out;
slot.image_just_processed = true; // do not checkpoint right after an image chunk
}
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;
// add prompt tokens for processing in the current batch
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt];
if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk
}
if (slot.ga_n != 1) {
// group-attention self-extend position remapping
while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w / ga_n) * (ga_n - 1);
slot_npast -= bd;
ga_i += ga_w / ga_n;
}
}
int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
common_batch_add(batch, cur_tok, p0, { slot.id }, slot.need_embd());
slot.cache_tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
slot_npast++;
slot.n_past_prompt++;
slot.n_past++;
slot.image_just_processed = false;
// stop early so a checkpoint can be taken near the end of the prompt
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
slot.do_checkpoint = true;
break;
}
}
LOG_VERBOSE("prompt processing progress", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
{"n_ctx", n_ctx},
{"n_tokens", batch.n_tokens},
{"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
});
// entire prompt has been processed - start decoding new tokens
if (slot.n_past_prompt == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
// re-seed the sampler with the full prompt (grammar not applied)
common_sampler_reset(slot.ctx_sampling);
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
llama_token id = slot.prompt_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.ctx_sampling, ctx, id, false);
}
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
LOG_VERBOSE("prompt done", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
{"n_ctx", n_ctx},
{"n_tokens", batch.n_tokens},
});
}
}
if (batch.n_tokens >= n_batch) {
break;
}
}
}
}
// Apply group-attention Self-Extend position remapping to the KV cache for
// slots with ga_n != 1, then advance their self-extend position counter by
// the number of tokens about to be decoded.
void server_context::extend_context(const int32_t n_tokens) {
for (auto& slot : slots) {
if (slot.ga_n != 1) {
// context extension via Self-Extend
// TODO: simplify and/or abstract this
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
// shift, compress the current window, then shift the remainder back
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;
slot.ga_i += slot.ga_w / slot.ga_n;
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
slot.n_past_se += n_tokens;
}
}
}
void server_context::speculative_decoding_accept() {
for (auto& slot : slots) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
continue;
}
size_t n_draft = slot.drafted.size();
apply_server_biases(slot);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
if (slot.has_mtp) {
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
if (!ids.empty()) {
const float* emb = llama_get_embeddings(ctx);
if (emb) {
slot.mtp_hidden_state.resize(ids.size() * n_embd);
memcpy(slot.mtp_hidden_state.data(), emb, ids.size() * n_embd * sizeof(float));
}
} else {
const float* emb0 = llama_get_embeddings_ith(ctx, 0);
if (emb0) {
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb0, n_embd * sizeof(float));
}
}
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
mtp_accept_tokens(mtp_target, ids, n_past_base, slot.id);
}
slot.i_batch_dft.clear();
slot.drafted.clear();
slot.n_past += ids.size();
slot.n_decoded += ids.size();
const int64_t t_current = ggml_time_us();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
// rollback to the state before sampling the draft tokens
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
// add accepted tokens to the prompt
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
slot.sampled = ids.back(); // last accepted token
slot.n_past = slot.cache_tokens.n_tokens();
// for hybrid models: if any drafts were rejected, restore recurrent state
const bool any_rejected = (ids.size() - 1) < n_draft;
if (any_rejected && slot.spec_ckpt_valid) {
llama_state_seq_set_data(ctx, slot.spec_ckpt_data.data(), slot.spec_ckpt_data.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
llama_kv_cache_seq_rm(ctx, slot.id, slot.spec_ckpt_n_past, -1);
// restore sampler state (RNG, grammar, prev tokens)
if (slot.spec_ckpt_sampler) {
common_sampler_clone(slot.spec_ckpt_sampler, slot.ctx_sampling);
common_sampler_free(slot.spec_ckpt_sampler);
slot.spec_ckpt_sampler = nullptr;
}
if (!ids.empty()) {
const int n_accepted = (int)ids.size();
llama_batch re_batch = llama_batch_init(n_accepted, 0, 1);
for (int j = 0; j < n_accepted; j++) {
const bool is_last = (j == n_accepted - 1);
common_batch_add(re_batch, ids[j], slot.spec_ckpt_n_past + j, { slot.id }, is_last);
}
if (slot.has_mtp) {
llama_set_embeddings(ctx, true);
}
const int ret = llama_decode(ctx, re_batch);
if (ret != 0) {
SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret);
}
if (slot.has_mtp) {
llama_set_embeddings(ctx, false);
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
const float * emb = llama_get_embeddings_ith(ctx, -1);
if (emb) {
slot.mtp_hidden_state.resize(n_embd);
memcpy(slot.mtp_hidden_state.data(), emb, n_embd * sizeof(float));
}
}
llama_batch_free(re_batch);
SLT_DBG(slot, "spec checkpoint restored: re-decoded %d accepted tokens (rejected %d)\n",
n_accepted, (int)(n_draft - (ids.size() - 1)));
}
slot.spec_ckpt_valid = false;
} else {
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
slot.spec_ckpt_valid = false;
// discard saved sampler on full acceptance
if (slot.spec_ckpt_sampler) {
common_sampler_free(slot.spec_ckpt_sampler);
slot.spec_ckpt_sampler = nullptr;
}
}
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
}
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.cache_tokens.push_back(slot.sampled);
slot.n_past++;
send_final_response(slot);
release_slot_after_final_response(slot);
break;
}
} else {
buffer_and_check_string_ban(slot, result);
if (slot.task == nullptr) {
break;
}
}
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
update_allowlist_state(slot);
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
LOG_VERBOSE("speculative decoding result", {
{"id_slot", slot.id},
{"accepted", (int)ids.size() - 1},
{"total", (int)slot.drafted.size()},
{"new_n_past", slot.n_past}
});
}
}
// A special token is rendered to text when the server runs with special-token
// output enabled, or when the slot's sampling params explicitly preserve it.
bool server_context::accept_special_token(const server_slot& slot, const llama_token token) {
    if (params_base.special) {
        return true;
    }
    const auto & preserved = slot.sparams.preserved_tokens;
    return preserved.find(token) != preserved.end();
}
// Finalize a slot after its final response has been sent: log timings,
// optionally checkpoint the finished sequence, release the slot back to the
// pool, and record prediction metrics. Order matters: timings/checkpoint must
// run before release().
void server_context::release_slot_after_final_response(server_slot & slot) {
slot.print_timings();
if (params_base.do_checkpoint) {
create_checkpoint(slot);
}
slot.release();
slot.released = true;
metrics.on_prediction(slot);
}
// Flush up to n buffered token results to the client (n <= 0 means all).
// Handles stop conditions mid-flush, cleans up positional bans for confirmed
// tokens, and erases the sent entries from `results` afterwards.
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
int count = 0;
bool released = false;
// absolute position of the first buffered token in the sequence
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
for (auto& it : results) {
bool has_next = process_token(it, slot);
// Clean up positional bans for the token we just confirmed/sent
slot.positional_bans.erase(start_pos + count);
count++;
if (!has_next) {
// stopped by generation limit only: keep flushing the remaining
// buffered tokens and finalize after the loop
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
continue;
}
slot.cache_tokens.push_back(slot.sampled);
slot.n_past++;
send_final_response(slot);
release_slot_after_final_response(slot);
released = true;
break;
}
if (n > 0 && count >= n) {
break;
}
}
// limit-stop path deferred from the loop above
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
slot.cache_tokens.push_back(slot.sampled);
slot.n_past++;
send_final_response(slot);
release_slot_after_final_response(slot);
}
if (count > 0) {
// NOTE(review): slot.sampled is set to the last element of `results`,
// not the last *sent* element - confirm this is intended when count < results.size()
slot.sampled = results[results.size()-1].tok;
results.erase(results.begin(), results.begin() + count);
}
}
// Scan the slot's buffered token text for banned phrases/regexes and return
// the absolute sequence position of the token where the earliest match starts,
// or -1 if nothing matched. Phrase matching is case-insensitive (the buffer is
// lowercased - assumes slot.ban_phrases are stored lowercase; regex patterns
// use their own case handling).
inline int32_t check_ban_phrase(server_slot& slot) {
if (slot.token_buffer.empty()) return -1;
std::string string_buffer;
std::vector<size_t> token_offsets;
// concatenate buffered token texts, remembering each token's start offset
for (const auto& it : slot.token_buffer) {
token_offsets.push_back(string_buffer.size());
string_buffer += it.text_to_send;
}
size_t best_start = std::string::npos;
bool found = false;
std::string string_buffer_lower = string_lower(string_buffer);
// 1. Check strings
for (const auto& phrase : slot.ban_phrases) {
size_t start = string_buffer_lower.find(phrase);
if (start != std::string::npos) {
if (start < best_start) {
best_start = start;
found = true;
}
}
}
// 2. Check regex
// NOTE(review): patterns are recompiled on every call - acceptable only if
// this is not on a hot path; invalid patterns are silently skipped
for (const auto& pattern : slot.ban_regex) {
try {
std::regex re(pattern);
std::smatch match;
if (std::regex_search(string_buffer, match, re)) {
if (match.position() < best_start) {
best_start = match.position();
found = true;
}
}
} catch (...) { continue; }
}
// 3. Check regex case insensitive
for (const auto& pattern : slot.ban_regex_ci) {
try {
std::regex re(pattern, std::regex_constants::icase);
std::smatch match;
if (std::regex_search(string_buffer, match, re)) {
if (match.position() < best_start) {
best_start = match.position();
found = true;
}
}
} catch (...) { continue; }
}
if (found) {
// map the byte offset of the match back to the owning token index
int32_t token_idx = -1;
for (size_t i = 0; i < token_offsets.size(); ++i) {
size_t len = (i == token_offsets.size() - 1)
? string_buffer.size() - token_offsets[i]
: token_offsets[i+1] - token_offsets[i];
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
token_idx = (int32_t)i;
break;
}
}
if (token_idx != -1) {
// convert buffer-relative token index to an absolute sequence position
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
return abs_pos;
}
}
return -1;
}
// Roll the slot back to just before `ban_pos` (absolute sequence position of
// the first banned token): record positional token bans so regeneration avoids
// the banned continuation, then truncate the token cache, KV cache, and token
// buffer accordingly.
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
slot.rewind_count++;
// absolute position of the first token currently in the buffer
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
if (n_keep_buffer < 0) n_keep_buffer = 0;
// banned_n: 0 = ban none, -1 = ban all buffered tokens from ban_pos on,
// >0 = ban at most that many
if (slot.banned_n != 0) {
int32_t n = 0;
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
llama_token banned_tok = result->tok;
if (n == 0) {
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
ban_pos, banned_tok, result->text_to_send.c_str());
}
slot.positional_bans[ban_pos].insert(banned_tok);
n++;
if (slot.banned_n > 0 && n == slot.banned_n) {
break;
}
}
}
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
size_t n_keep_cache = 0;
if (ban_pos > 0) {
n_keep_cache = (size_t)(ban_pos - 1);
}
if (n_keep_cache > slot.cache_tokens.size()) {
n_keep_cache = slot.cache_tokens.size();
}
// the token right after the kept prefix becomes the "last sampled" token
if (n_keep_cache < slot.cache_tokens.size()) {
slot.sampled = slot.cache_tokens[n_keep_cache];
} else {
slot.sampled = 0;
}
// Truncate cache
slot.cache_tokens.keep_first(n_keep_cache);
slot.n_past = slot.cache_tokens.n_tokens();
// Remove from KV cache
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
// Truncate buffer
slot.token_buffer.resize(n_keep_buffer);
// Adjust decoded count
// (only when the predict budget should not be consumed by rewound tokens)
if (slot.saturate_predict) {
slot.n_decoded -= n_rewind_total;
if (slot.n_decoded < 0) slot.n_decoded = 0;
}
}
// Buffer a newly generated token and decide between three outcomes: rewind
// the context when a banned phrase/regex is detected (subject to rewind
// limits), flush buffered tokens when the buffer is full or generation ended,
// or simply keep buffering.
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
slot.token_buffer.push_back(result);
bool next_token = has_next_token(result, slot);
// If buffer full or generation stopped, we might send tokens
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
int32_t ban_pos = -1;
bool sent_results = false;
// Always reset logit bias to base before checking bans
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
ban_pos = check_ban_phrase(slot);
}
bool allow_rewind = true;
if (ban_pos >= 0) {
if (slot.rewind_count_max == -1) {
// Automatic / Heuristic logic
// Account for strings + regex + regex_ci
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
// Heuristic: Allow if under 20 OR under 2 * total_bans
// Conversely: Stop if >= 20 AND > 2 * total_bans
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
allow_rewind = false;
}
}
else if (slot.rewind_count_max > 0) {
// Strict limit logic
if (slot.rewind_count >= slot.rewind_count_max) {
allow_rewind = false;
}
}
// If slot.rewind_count_max == 0, allow_rewind remains true (Infinite)
}
if (ban_pos >= 0 && allow_rewind) {
// banned content found: roll back and regenerate
rewind_context(slot, ban_pos);
slot.rewind_status = true;
}
else if (buffer_full || !next_token) {
slot.rewind_status = false;
slot.rewind_count = 0;
if (!next_token) {
// send all remaining tokens
send_token_results(slot.token_buffer, slot);
}
else {
// send 1 token from the front (FIFO)
send_token_results(slot.token_buffer, slot, 1);
}
}
else {
// buffer the result, wait for more tokens to validate string
slot.sampled = result.tok;
}
}
// Advance the slot's allow-keyword cursor: once the configured delay has
// elapsed, scan the tail of the generated text for the next expected keyword
// and step through as many consecutive keywords as are found in order.
void server_context::update_allowlist_state(server_slot& slot) {
    const auto& keywords = slot.allow_kws;
    auto& cursor = slot.allow_idx;

    // nothing to do before the delay elapses or once all keywords matched
    if (slot.allow_kw_delay > slot.n_decoded || cursor >= keywords.size()) {
        return;
    }

    // only the newly appended text can contain a fresh match, so start the
    // search just far enough back to catch a keyword straddling the boundary
    auto kw = keywords[cursor];
    auto pos = slot.generated_text.find(kw, std::max(0, slot.last_gentxt_size - (int32_t)kw.length() + 1));
    while (pos != std::string::npos) {
        if (++cursor >= keywords.size()) {
            break;
        }
        kw = keywords[cursor];
        pos = slot.generated_text.find(kw, pos + 1);
    }
}
void server_context::process_batch_tokens(int32_t & n_batch) {
    // Decode the accumulated `batch` in chunks of at most n_batch tokens, then
    // sample one token per active slot and dispatch the results. When the KV
    // cache has no free space the chunk size is halved and the chunk retried;
    // on a hard decode error (or n_batch == 1) every slot is released.
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
        extend_context(n_tokens);
        // non-owning view into `batch` covering tokens [i, i + n_tokens)
        llama_batch batch_view = {
            n_tokens,
            batch.token + i,
            nullptr,
            batch.pos + i,
            batch.n_seq_id + i,
            batch.seq_id + i,
            batch.logits + i,
            0, 0, 0, // unused
        };
        const int ret = llama_decode(ctx, batch_view);
        if (ret != 0) {
            // give up when we cannot shrink further (n_batch == 1) or on a
            // hard error (ret < 0); otherwise retry with a smaller chunk below
            if (n_batch == 1 || ret < 0) {
                int user_cancel = -3; // decode return code for a user-initiated cancel
                if (ret == user_cancel) {
                    LLAMA_LOG_INFO("Decode process is cancelled by user.\n");
                }
                else {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
                    // NOTE(review): the "n_batch" field below logs `ret`, not
                    // n_batch — looks like a copy/paste slip; confirm intent.
                    LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
                        {"i", i},
                        {"n_batch", ret},
                        {"ret", ret},
                    });
                }
                // release every slot; report the error unless the user cancelled
                for (auto& slot : slots) {
                    slot.state = SLOT_STATE_PROCESSING;
                    slot.command = SLOT_COMMAND_NONE;
                    slot.release();
                    if (ret != user_cancel) {
                        LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
                    }
                }
                break; // break loop of n_batch
            }
            // retry with half the batch size to try to find a free slot in the KV cache
            n_batch /= 2;
            i -= n_batch; // rewind so the same chunk is retried at the smaller size
            LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
                {"i", i},
                {"n_batch", n_batch},
                {"ret", ret},
            });
            continue; // continue loop of n_batch
        }
        // MTP warm-up: when any slot is at the start of prompt processing,
        // capture the embeddings of this chunk so the draft/MTP context can be
        // primed after the per-slot sampling loop below.
        bool mtp_warmup_needed = false;
        std::vector<float> batch_mtp_hidden_state;
        if (params_base.has_mtp) {
            for (auto& slot : slots) {
                if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
                    (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
                    // NOTE(review): only the first token of the view is
                    // inspected — assumes prompt chunks are not interleaved
                    // across slots within one view; confirm with batching code.
                    bool has_tokens_for_slot = (batch_view.n_tokens > 0 && batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id);
                    if (has_tokens_for_slot) {
                        mtp_warmup_needed = true;
                        break;
                    }
                }
            }
            if (mtp_warmup_needed) {
                const float* emb = llama_get_embeddings(ctx);
                const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                const int n_toks = batch_view.n_tokens;
                if (emb) {
                    // copy out: the context's embedding buffer is overwritten
                    // by subsequent decodes
                    batch_mtp_hidden_state.resize(n_toks * n_embd);
                    memcpy(batch_mtp_hidden_state.data(), emb, n_toks * n_embd * sizeof(float));
                }
            }
        }
        // per-slot sampling over the tokens of this chunk
        for (auto& slot : slots) {
            bool is_active_slot = (slot.state == SLOT_STATE_PROCESSING);
            // skip slots whose sampling index is not inside [i, i + n_tokens)
            if (!is_active_slot || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
                // save checkpoint during prompt processing
                if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
                    if (slot.do_checkpoint) {
                        create_checkpoint(slot);
                    } else {
                        create_checkpoint_at_interval(slot, params_base);
                    }
                }
                continue; // continue loop of slots
            }
            // prompt evaluated for embedding
            if (slot.embedding) {
                send_embedding(slot, batch_view);
                slot.release();
                slot.i_batch = -1;
                continue; // continue loop of slots
            }
            // first generated token: initialize the speculative decoder with
            // the prompt tokens already in cache
            if (slot.n_decoded == 0 && slot.can_speculate()) {
                common_speculative_begin(slot.spec, slot.cache_tokens.get_text_tokens());
            }
            // draft tokens present: sampling/acceptance is handled by
            // speculative_decoding_accept() at the end of this chunk
            if (slot.i_batch_dft.size() > 0) {
                continue; // sample using speculative decoding
            }
            // RESTORE AND APPLY POSITIONAL BANS
            // reset to the base biases, then add position-specific bans for
            // the current position n_past
            slot.ctx_sampling->params.logit_bias = slot.logit_bias;
            auto ban_it = slot.positional_bans.find(slot.n_past);
            if (ban_it != slot.positional_bans.end()) {
                for (llama_token tok : ban_it->second) {
                    slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
                }
            }
            completion_token_output result;
            // index of this slot's logits within the current view
            const int tok_idx = slot.i_batch - i;
            // capture the last prompt token's hidden state for MTP drafting
            if (params_base.has_mtp && slot.n_decoded == 0) {
                const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
                if (emb_i) {
                    const int n_embd = llama_model_n_embd(llama_get_model(ctx));
                    slot.mtp_hidden_state.resize(n_embd);
                    memcpy(slot.mtp_hidden_state.data(), emb_i, n_embd * sizeof(float));
                }
            }
            apply_server_biases(slot);
            const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
            common_sampler_accept(slot.ctx_sampling, ctx, id, true);
            slot.n_decoded += 1;
            const int64_t t_current = ggml_time_us();
            if (slot.n_decoded == 1) {
                // first token ends prompt processing: record timings/metrics
                slot.t_start_generation = ggml_time_us();
                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                metrics.on_prompt_eval(slot);
                // create checkpoint after prompt processing ends
                if (params_base.ctx_checkpoints_tolerance<=0 && params_base.do_checkpoint) {
                    create_checkpoint(slot);
                }
            }
            // create checkpoint during generation
            if (slot.n_decoded > 1) {
                create_checkpoint_at_interval(slot, params_base);
            }
            // elapsed generation time in ms (clamped to >= 1 us to avoid zero)
            slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
            result.tok = id;
            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
            if (slot.sparams.n_probs > 0) {
                populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
            }
            // no ban string for recurrent/hybrid model
            if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
                slot.token_buffer = { result };
                send_token_results(slot.token_buffer, slot);
            } else {
                // buffer the result and check string ban.
                // if ban, we need to go back, apply logit bias and regenerate
                buffer_and_check_string_ban(slot, result);
            }
            common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
            update_allowlist_state(slot);
            slot.i_batch = -1;
        }
        // feed the captured prompt hidden states into the MTP context (a
        // dedicated MTP context if any slot has one, else the main context)
        if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
            llama_context * mtp_ctx = nullptr;
            for (auto & slot : slots) {
                if (slot.spec && slot.has_mtp) {
                    llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
                    if (mc) { mtp_ctx = mc; break; }
                }
            }
            llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
            llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
            mtp_update_kv_cache(mtp_target, batch_view, true);
        }
        // speculative decoding - main model sample and accept
        speculative_decoding_accept();
    }
}
void server_context::update_slots() {
    // One scheduler iteration: release finished slots, assemble the next
    // batch (already-sampled tokens plus pending prompt chunks), snapshot
    // recurrent/hybrid state needed by speculative decoding, then decode.
    if (system_need_update) {
        system_prompt_update();
    }
    // release slots
    release_slots();
    // check if all slots are idle
    if (slots_idle()) {
        return;
    }
    {
        // keep the task loop alive so the next iteration gets scheduled
        LOG_VERBOSE("posting NEXT_RESPONSE", {});
        server_task task;
        task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
        task.id_target = -1;
        queue_tasks.post(std::move(task));
    }
    // apply context-shift if needed
    // TODO: simplify and improve
    context_shift();
    // start populating the batch for this iteration
    common_batch_clear(batch);
    // first, add sampled tokens from any ongoing sequences
    add_sampled_tokens(); // Prepare batch for inference
    // process in chunks of params.n_batch
    int32_t n_batch = llama_n_batch(ctx);
    int32_t n_ubatch = llama_n_ubatch(ctx);
    // track if this is an embedding or non-embedding batch
    // if we've added sampled tokens above, we are in non-embedding mode
    // -1: none, 0: non-embedding, 1: embedding
    int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
    // next, batch any pending prompts without exceeding n_batch
    batch_pending_prompt(n_ubatch, n_batch, batch_type); // Prepare batch for prompt process
    if (batch.n_tokens == 0) {
        LOG_VERBOSE("no tokens to decode", {});
        return;
    }
    LOG_VERBOSE("decoding batch", {
        {"n_tokens", batch.n_tokens},
    });
    // make sure we're in the right embedding mode
    llama_set_embeddings(ctx, batch_type == 1);
    if (llama_model_is_hybrid(model)) {
        // Recurrent/hybrid models cannot undo a decode by truncating the KV
        // cache, so before decoding a speculative batch we snapshot the
        // per-sequence state AND the sampler (RNG, grammar, prev tokens) for
        // every slot with drafted tokens; both are restored if the draft is
        // rejected.
        for (auto & slot : slots) {
            if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
                continue;
            }
            // position before the drafted tokens (+1 for the token that seeded the draft)
            slot.spec_ckpt_n_past = slot.n_past - (int32_t)(slot.drafted.size() + 1);
            const size_t ckpt_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
            slot.spec_ckpt_data.resize(ckpt_size);
            const size_t written = llama_state_seq_get_data(ctx, slot.spec_ckpt_data.data(), ckpt_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
            slot.spec_ckpt_valid = (written > 0);
            if (slot.spec_ckpt_valid) {
                // save sampler state so we can restore RNG/grammar on rejection
                if (slot.spec_ckpt_sampler) {
                    common_sampler_free(slot.spec_ckpt_sampler);
                }
                slot.spec_ckpt_sampler = common_sampler_init(model, slot.sparams);
                common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt_sampler);
                SLT_DBG(slot, "spec checkpoint saved: %zu bytes, n_past_pre_spec=%d\n", written, slot.spec_ckpt_n_past);
            } else {
                SLT_WRN(slot, "%s", "failed to save spec checkpoint\n");
            }
        }
    }
    // process the created batch of tokens
    process_batch_tokens(n_batch); // Decode with batch
    LOG_VERBOSE("run slots completed", {});
}
json server_context::model_meta() const {
    // Collect basic metadata about the loaded model as a JSON object.
    json meta;
    meta["vocab_type"]  = llama_vocab_type(llama_model_get_vocab(model));
    meta["n_vocab"]     = llama_n_vocab(model);
    meta["n_ctx_train"] = llama_n_ctx_train(model);
    meta["n_embd"]      = llama_model_n_embd(model);
    meta["n_params"]    = llama_model_n_params(model);
    meta["size"]        = llama_model_size(model);
    return meta;
}