ik_llama.cpp/examples/server/server-context.h
Last commit: 2a633c4357 by firecoperana (2025-12-22 09:46:45 +01:00)
server: exclude thinking tokens when finding the slot (#1079)
- refactor find slot
- enable by default
- Fix load prompt
- rename variables
Co-authored-by: firecoperana <firecoperana>


#include "server-task.h"
#include "server-queue.h"
#include "speculative.h"
#include "json-schema-to-grammar.h"
#include <nlohmann/json_fwd.hpp>
#include <cstddef>
#include <memory>

enum slot_state {
    SLOT_STATE_IDLE,
    SLOT_STATE_PROCESSING,
};

enum slot_command {
    SLOT_COMMAND_NONE,
    SLOT_COMMAND_LOAD_PROMPT,
    SLOT_COMMAND_RELEASE,
};
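
// NOTE (illustrative, not part of the original header): judging by the names, a slot is
// assumed to move through a simple lifecycle:
//
//     SLOT_STATE_IDLE --(SLOT_COMMAND_LOAD_PROMPT)--> SLOT_STATE_PROCESSING
//         (tokenize/ingest the prompt, then generate)
//     SLOT_STATE_PROCESSING --(SLOT_COMMAND_RELEASE)--> SLOT_STATE_IDLE
//         (results sent, cached tokens kept for possible reuse)
//
// SLOT_COMMAND_NONE simply means "no pending command for this slot".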

struct slot_params {
    bool stream        = true;
    bool include_usage = false;
    bool cache_prompt  = true; // remember the prompt to avoid reprocessing the entire prompt

    int32_t n_keep    =  0; // number of tokens to keep from the initial prompt
    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict = -1; // new tokens to predict

    thinking_tokens think_tokens;

    std::vector<std::string> antiprompt;

    bool timings_per_token   = false;
    bool post_sampling_probs = false;

    json input_prefix;
    json input_suffix;

    // speculative decoding parameters
    struct {
        int   n_max = 16;    // max drafted tokens
        int   n_min = 0;     // min drafted tokens to accept
        float p_min = 0.75f; // min probability required to accept a token in the draft
    } speculative;

    // OAI-compat fields
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_syntax oaicompat_chat_syntax;
};
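
// NOTE (illustrative, not part of the original header): the speculative parameters are
// assumed to behave like their upstream llama.cpp counterparts, roughly:
//
//     1. the draft model proposes up to speculative.n_max tokens, stopping early once its
//        confidence in the next draft token falls below speculative.p_min;
//     2. if fewer than speculative.n_min tokens were drafted, the draft is discarded and
//        the step falls back to normal single-token decoding;
//     3. the target model then verifies the surviving draft tokens in one batch.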

struct server_slot {
    int id;
    int id_task  = -1;
    int id_multi = -1;

    struct slot_params params;

    slot_state   state   = SLOT_STATE_IDLE;
    slot_command command = SLOT_COMMAND_NONE;

    llama_context* ctx = nullptr;

    // used to determine the slot that has been used the longest
    int64_t t_last_used = -1;

    std::unique_ptr<const server_task> task;

    // generation props
    int32_t n_ctx              =  0; // context size per slot
    int32_t n_past             =  0;
    int32_t n_past_prompt      =  0;
    int32_t n_decoded          =  0;
    int32_t n_remaining        = -1;
    int32_t n_discarded_prompt =  0;
    int32_t n_kept_prompt      =  0;
    int32_t i_batch            = -1;
    int32_t n_predict          = -1; // TODO: disambiguate from params.n_predict

    int32_t n_prompt_tokens           = 0;
    int32_t n_prompt_tokens_processed = 0;

    json prompt; // can be either a string, array of strings or array of token ids

    // when a task is submitted, we first tokenize the prompt and store it here
    server_tokens prompt_tokens;
    server_tokens cache_tokens;

    std::string generated_text;
    std::vector<completion_token_output> generated_token_probs;

    common_chat_msg chat_msg;

    bool infill         = false;
    bool embedding      = false;
    bool has_next_token = true;
    bool truncated      = false;
    bool stopped_eos    = false;
    bool stopped_word   = false;
    bool stopped_limit  = false;

    bool oaicompat = false;

    std::string oaicompat_model;
    std::string stopping_word;

    stop_type stop;

    server_prompt server_cached_prompt;

    void prompt_save(server_prompt_cache& prompt_cache) const;
    void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);

    // sampling
    llama_token sampled;
    struct llama_sampling_params sparams;
    llama_sampling_context* ctx_sampling = nullptr;
    json json_schema;

    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    std::vector<std::string> generated_tool_call_ids;

    int32_t ga_i = 0;   // group-attention state
    int32_t ga_n = 1;   // group-attention factor
    int32_t ga_w = 512; // group-attention width

    // multimodal
    mtmd_context* mctx = nullptr;

    // speculative decoding
    struct llama_speculative* spec = nullptr;
    llama_context* ctx_dft = nullptr;
    llama_batch batch_spec = {};

    // speculative decoding stats
    int32_t n_draft_total    = 0; // total draft tokens generated
    int32_t n_draft_accepted = 0; // draft tokens actually accepted

    int32_t n_past_se = 0; // self-extend

    // stats
    size_t n_sent_text        = 0; // number of sent text characters
    size_t n_sent_token_probs = 0;

    int64_t t_start_process_prompt;
    int64_t t_start_generation;

    double t_prompt_processing; // ms
    double t_token_generation;  // ms

    void reset();
    bool has_budget(gpt_params& global_params);
    bool available() const;
    bool is_processing() const;
    void add_token_string(const completion_token_output& token);
    void release();
    json get_formated_timings() const;
    result_timings get_timings() const;
    const common_chat_msg& update_chat_msg(std::vector<common_chat_msg_diff>& diffs);
    size_t find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop);
    void print_timings() const;
};
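
// NOTE (illustrative, not part of the original header): since t_prompt_processing and
// t_token_generation are stored in milliseconds, the throughput reported by
// print_timings() / get_formated_timings() is presumably derived along these lines:
//
//     double prompt_tps = 1e3 * n_prompt_tokens_processed / t_prompt_processing;
//     double gen_tps    = 1e3 * n_decoded                 / t_token_generation;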

struct server_metrics {
    int64_t t_start = 0;

    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total       = 0;
    uint64_t n_tokens_predicted_total        = 0;
    uint64_t t_tokens_generation_total       = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing       = 0;
    uint64_t n_tokens_predicted        = 0;
    uint64_t t_tokens_generation       = 0;

    void init();
    void on_prompt_eval(const server_slot& slot);
    void on_prediction(const server_slot& slot);
    void reset_bucket();
};
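
// NOTE (illustrative, not part of the original header): the *_total fields appear to
// accumulate over the server's lifetime, while the unsuffixed fields form a bucket that
// reset_bucket() clears (e.g. between /metrics scrapes), with the callbacks doing roughly:
//
//     void server_metrics::on_prompt_eval(const server_slot& slot) {
//         n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
//         n_prompt_tokens_processed       += slot.n_prompt_tokens_processed;
//         t_prompt_processing_total       += slot.t_prompt_processing;
//         t_prompt_processing             += slot.t_prompt_processing;
//     }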

struct server_context {
    llama_model*   model = nullptr;
    llama_context* ctx   = nullptr;

    std::vector<llama_lora_adapter_container> lora_adapters;

    gpt_params params;

    llama_batch batch;

    bool clean_kv_cache = true;
    bool add_bos_token  = true;
    bool has_eos_token  = false;

    // multimodal
    mtmd_context* mctx = nullptr;

    // for speculative decoding
    llama_model*   model_draft = nullptr;
    llama_context* ctx_draft   = nullptr;
    llama_context_params cparams_dft;

    int32_t n_ctx; // total context for all clients / slots

    // system prompt
    bool system_need_update = false;

    std::string              system_prompt;
    std::vector<llama_token> system_tokens;

    // slots / clients
    std::vector<server_slot> slots;
    json default_generation_settings_for_props;

    server_queue    queue_tasks;
    server_response queue_results;

    std::unique_ptr<server_prompt_cache> prompt_cache;

    server_metrics metrics;

    common_chat_templates_ptr chat_templates;
    oaicompat_parser_options  oai_parser_opt;

    // necessary similarity of prompt for slot selection
    float slot_prompt_similarity = 0.0f;

    int32_t cache_ram_n_min      = 0;
    float   cache_ram_similarity = 0.5f;

    ~server_context();

    bool load_model(const gpt_params& params_);
    void init();

    std::vector<llama_token> tokenize(const json& json_prompt, bool add_special) const;

    server_slot* get_slot_by_id(int id);

    float calculate_slot_f_keep(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);
    std::pair<common_prefix, float> calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);
    void copy_data_to_cached_prompt(const server_tokens& tokens, server_slot& slot);

    server_slot* get_available_slot(const server_task& task);

    bool launch_slot_with_task(server_slot& slot, server_task& task);

    void kv_cache_clear();

    void system_prompt_update();
    bool system_prompt_set(const std::string& sys_prompt);

    bool process_token(completion_token_output& result, server_slot& slot);
    void populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx);

    json get_formated_generation(const server_slot& slot) const;

    void send_error(const server_task& task, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
    void send_error(const server_slot& slot, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
    void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);

    // if multimodal is enabled, send an error and return false
    bool ensure_no_mtmd(const int id_task);

    void send_partial_response(server_slot& slot, completion_token_output tkn);
    void send_final_response(server_slot& slot);
    void send_embedding(const server_slot& slot, const llama_batch& batch);

    void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs);
    void request_cancel(int id_task);

    void split_multiprompt_task(int id_multi, server_task& multiprompt_task);
    void process_single_task(server_task&& task);
    void on_finish_multitask(const server_task_multi& multitask);

    void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10);

    // discard tokens in the kv cache and in the cached tokens
    void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard);

    // find how many leading tokens to keep and how many following tokens to discard in order to convert a into b
    void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
                                     int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false);

    // handle context shift for the prompt
    void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false);
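
    // NOTE (illustrative, not part of the original header): context shifting is assumed to
    // follow the usual llama.cpp scheme: keep the first n_keep tokens, drop the next
    // n_discard tokens from the KV cache (and from the slot's cache_tokens), then shift the
    // remainder down so generation can continue. With params.n_discard == 0 ("defaults to
    // half"), roughly half of the tokens past n_keep are presumably discarded, e.g.:
    //
    //     n_past = 4096, n_keep = 256  ->  n_discard ~ (4096 - 256) / 2 = 1920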

    void update_slots();

    json model_meta() const;
};
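
// NOTE (illustrative, not part of the original header): per the commit message above
// ("exclude thinking tokens when finding the slot"), get_available_slot() presumably scores
// idle slots by how much of the incoming prompt already sits in their cache_tokens, skipping
// thinking/reasoning tokens during the comparison, and falls back to the least recently used
// slot (t_last_used) when no cache is similar enough. A rough sketch of the similarity test:
//
//     auto [prefix, sim] = calculate_slot_similarity(slot, ctx, slot.cache_tokens, prompt_tokens);
//     if (sim >= slot_prompt_similarity) {
//         // reuse this slot: the common prefix can stay in the KV cache instead of being reprocessed
//     }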