#include "server-task.h" #include "server-queue.h" #include "speculative.h" #include "json-schema-to-grammar.h" #include #include #include enum slot_state { SLOT_STATE_IDLE, SLOT_STATE_PROCESSING, }; enum slot_command { SLOT_COMMAND_NONE, SLOT_COMMAND_LOAD_PROMPT, SLOT_COMMAND_RELEASE, }; struct slot_params { bool stream = true; bool include_usage = false; bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half int32_t n_predict = -1; // new tokens to predict thinking_tokens think_tokens; std::vector antiprompt; bool timings_per_token = false; bool post_sampling_probs = false; json input_prefix; json input_suffix; // speculative decoding parameters struct { int n_max = 16; // max drafted tokens int n_min = 0; // min drafted tokens to accept float p_min = 0.75f; // min probability required to accept a token in the draft } speculative; // OAI-compat fields oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; common_chat_syntax oaicompat_chat_syntax; }; struct server_slot { int id; int id_task = -1; int id_multi = -1; struct slot_params params; slot_state state = SLOT_STATE_IDLE; slot_command command = SLOT_COMMAND_NONE; llama_context* ctx = nullptr; // used to determine the slot that has been used the longest int64_t t_last_used = -1; std::unique_ptr task; // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; int32_t n_past_prompt = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t n_discarded_prompt = 0; int32_t n_kept_prompt = 0; int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; json prompt; // can be either a string, array of strings or array of token ids // when a task is submitted, we first tokenize the prompt and store it here server_tokens prompt_tokens; server_tokens cache_tokens; std::string generated_text; std::vector generated_token_probs; common_chat_msg chat_msg; bool infill = false; bool embedding = false; bool has_next_token = true; bool truncated = false; bool stopped_eos = false; bool stopped_word = false; bool stopped_limit = false; bool oaicompat = false; std::string oaicompat_model; std::string stopping_word; stop_type stop; server_prompt server_cached_prompt; void prompt_save(server_prompt_cache& prompt_cache) const; void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens); // sampling llama_token sampled; struct llama_sampling_params sparams; llama_sampling_context* ctx_sampling = nullptr; json json_schema; common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; std::vector generated_tool_call_ids; int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width // multimodal mtmd_context* mctx = nullptr; // speculative decoding struct llama_speculative* spec = nullptr; llama_context* ctx_dft = nullptr; llama_batch batch_spec = {}; // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted int32_t n_past_se = 0; // self-extend // stats size_t n_sent_text = 0; // number of sent text character size_t n_sent_token_probs = 0; int64_t t_start_process_prompt; int64_t 
struct server_metrics {
    int64_t t_start = 0;

    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total       = 0;
    uint64_t n_tokens_predicted_total        = 0;
    uint64_t t_tokens_generation_total       = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing       = 0;

    uint64_t n_tokens_predicted  = 0;
    uint64_t t_tokens_generation = 0;

    void init();

    void on_prompt_eval(const server_slot& slot);
    void on_prediction(const server_slot& slot);

    void reset_bucket();
};
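// Sketch of how these counters are commonly split (illustrative, assuming the usual
// pattern; the actual definitions live in the .cpp): the *_total fields accumulate for
// the lifetime of the server, while the remaining fields form a bucket that a metrics
// endpoint can report and then clear via reset_bucket(). A hypothetical body for
// on_prompt_eval() would look like:
//
//     n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
//     n_prompt_tokens_processed       += slot.n_prompt_tokens_processed;
//     t_prompt_processing_total       += slot.t_prompt_processing;
//     t_prompt_processing             += slot.t_prompt_processing;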
struct server_context {
    llama_model*   model = nullptr;
    llama_context* ctx   = nullptr;

    std::vector<llama_lora_adapter_container> lora_adapters; // element type assumed

    gpt_params params;

    llama_batch batch;

    bool clean_kv_cache = true;
    bool add_bos_token  = true;
    bool has_eos_token  = false;

    // multimodal
    mtmd_context* mctx = nullptr;

    // For speculative decoding
    llama_model*         model_draft = nullptr;
    llama_context*       ctx_draft   = nullptr;
    llama_context_params cparams_dft;

    int32_t n_ctx; // total context for all clients / slots

    // system prompt
    bool system_need_update = false;

    std::string              system_prompt;
    std::vector<llama_token> system_tokens;

    // slots / clients
    std::vector<server_slot> slots;

    json default_generation_settings_for_props;

    server_queue    queue_tasks;
    server_response queue_results;

    std::unique_ptr<server_prompt_cache> prompt_cache;

    server_metrics metrics;

    common_chat_templates_ptr chat_templates;

    oaicompat_parser_options oai_parser_opt;

    // Necessary similarity of prompt for slot selection
    float slot_prompt_similarity = 0.0f;

    int32_t cache_ram_n_min      = 0;
    float   cache_ram_similarity = 0.5f;

    ~server_context();

    bool load_model(const gpt_params& params_);

    void init();

    std::vector<llama_token> tokenize(const json& json_prompt, bool add_special) const;

    server_slot* get_slot_by_id(int id);

    float calculate_slot_f_keep(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b);

    std::pair<float, float> calculate_slot_similarity(const server_slot& slot, llama_context* ctx, const server_tokens& a, const server_tokens& b); // pair element types assumed

    void copy_data_to_cached_prompt(const server_tokens& tokens, server_slot& slot);

    server_slot* get_available_slot(const server_task& task);

    bool launch_slot_with_task(server_slot& slot, server_task& task);

    void kv_cache_clear();

    void system_prompt_update();
    bool system_prompt_set(const std::string& sys_prompt);

    bool process_token(completion_token_output& result, server_slot& slot);

    void populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx);

    json get_formated_generation(const server_slot& slot) const;

    void send_error(const server_task& task, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
    void send_error(const server_slot& slot, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
    void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);

    // if multimodal is enabled, send an error and return false
    bool ensure_no_mtmd(const int id_task);

    void send_partial_response(server_slot& slot, completion_token_output tkn);

    void send_final_response(server_slot& slot);

    void send_embedding(const server_slot& slot, const llama_batch& batch);

    void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs);

    void request_cancel(int id_task);

    void split_multiprompt_task(int id_multi, server_task& multiprompt_task);

    void process_single_task(server_task&& task);

    void on_finish_multitask(const server_task_multi& multitask);

    void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10);

    // discard tokens from the KV cache and from the cached tokens
    void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard);

    // find how many leading tokens to keep (n_kept) and how many following tokens to discard (n_discarded) when converting a to b
    void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep, int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false);

    // handle context shift for the prompt
    void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false);

    void update_slots();

    json model_meta() const;
};
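// Usage sketch (illustrative only; the actual wiring lives in the accompanying server
// sources and may differ). Only members declared above are used; the `running` flag and
// the parameter setup are placeholders:
//
//     server_context ctx_server;
//
//     gpt_params params;                    // assumed to be filled from the CLI
//     if (!ctx_server.load_model(params)) {
//         return 1;
//     }
//     ctx_server.init();
//
//     while (running) {
//         ctx_server.update_slots();        // process queued prompts and generate tokens for active slots
//     }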