Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-05-12 08:56:28 +00:00
Recent commits:
- Autoparser: complete refactoring of parser architecture, optional argument reshuffle capability, true streaming (#20177)
- Revert to OAI-compatible args (#20213)
- Fix structured outputs (#20223)
- Fix compile bug (#20203)
- common : gracefully handle incomplete output (#20191)
- PEG parser for LFM2 (#20251)
- common : map developer role to system (#20215)
- common : consolidate PEG string parsers (#20263)
- examples : fix empty items in json_schema_to_grammar.py (#19968)
- Reduce level of content parser warning message to avoid log spam on non-debug verbosity (#20347)
- Do not return if template parse failed
- Add arg to enable parallel tool calls
- common : fix incorrect uses of stoul (#20313)
- Add support for MiroThinker with new jinja template
- common/parser : handle reasoning budget (#20297)
- common/parser : use nlohmann::ordered_json to preserve parameter order (#20385)
- common/parser : add GigaChat V3/3.1 models support (#19931)
- common/parser : gracefully handle undetected tool parser, print error message (#20286)
- fix : prevent nullptr dereference (#20552)
- common : fix iterator::end() dereference (#20445)
- jinja : add capability check for object args (#20612)
- common/parser : add `--skip-chat-parsing` to force a pure content parser (#20289)
- common : rework gpt-oss parser (#20393)
- common : fix gpt-oss content removal (#20745)
- common/parser : add proper reasoning tag prefill reading (#20424)
- chat : handle tool calls with no required args in TAG_WITH_TAGGED format (#20764)
- common/parser : fix out_of_range crash in throw path (#20424 regression) (#20777)
- jinja : fix heap OOB read in value equality comparison (#20782)
- common : fix typo in debug log ('extracft' -> 'extract') (#20807)
- common/parser : fix bug causing subtle corruption of the generation prompt (#20825)
- jinja : refactor token advancement (#20864)
- common/autoparser : detect reasoning markers when enable_thinking changes the system prompt (#20859)
- common : replace wrap_for_generation with a prefix convenience function and fix gpt-oss (#20912)
- jinja : fix macro with kwargs (#20960)
- common : inhibit lazy grammar sampler while reasoning is active (#20970)
- common/parser : fix reasoning whitespace bugs + extra parser tests (#21085)
- common : add reasoning_format = none support to gpt-oss (#21094)
- common/json-schema : handle non-capturing groups (?:...) in the JSON schema pattern converter (#21124)
- common/parser : fix handling of tool definitions with a missing properties key (#21128)
- jinja : handle empty expressions correctly (#20913)
- common : gpt-oss handle builtin and unsolicited tool calls (#21213)
- fix : tool call parsing for LFM2 and LFM2.5 models (#21242)
- Relax prefill parser to allow space (#21240)
- common : add commentary rules for gpt-oss-20b (#21286)
- Add reasoning budget model
- mtmd : fix gguf conversion for audio/vision mmproj (#21309)
- fix : gemma 4 template (#21326)
- chat : avoid including json in chat.h (#21306)
- jinja : coerce input for string-specific filters (#21370)
- common : fix tool call type detection for nullable and enum schemas (#21327)
- common/parser : fix call ID detection (mostly Mistral parser) + atomicity for tag-json parsers (#21230)
- common : add gemma 4 specialized parser (#21418)
- Fix reasoning budget
- parser : fix MiniMax handling (#21573)
- jinja : support ensure_ascii=true, string repetition and int/float self-filtering (#21623)
- common : simplify autoparser tagged parser rules (#21216)
- common : fix ambiguous grammar rule in gemma4 (#21661)
- common : enable reasoning budget sampler for gemma4 (#21697)
- common : better align to the updated official gemma4 template (#21704)
- fix : broken structured output when using $refs in json_schema (#21699)
- chat : dedicated DeepSeek v3.2 parser + "official" template (#21785)
- Hide render_message_to_json warning
- common/gemma4 : handle parsing edge cases (#21760)
- common : skip reasoning budget sampler when no budget is requested (#21870)
- autoparser : support JSON_NATIVE with per-call markers (test case: Reka-Edge) (#21892)
4317 lines · 175 KiB · C++
#include "server-context.h"
|
|
#include "server-common.h"
|
|
#include "server-task.h"
|
|
#include "server-queue.h"
|
|
|
|
#include "common.h"
|
|
#include "llama.h"
|
|
#include "log.h"
|
|
#include "sampling.h"
|
|
#include "speculative.h"
|
|
#include "mtmd.h"
|
|
#include "mtmd-helper.h"
|
|
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <regex>
|
|
#include <exception>
|
|
|
|
static void log_text(const gpt_params & params_base, const std::string & text) {
    if (params_base.minilog) {
        LOG_TEE("%s\n", text.c_str());
    }
}

server_context::~server_context() {
    if (ctx) {
        llama_free(ctx);
        ctx = nullptr;
    }

    if (model) {
        llama_free_model(model);
        model = nullptr;
    }

    // Free multimodal
    mtmd_free(mctx);

    // Free draft model and context if they exist
    if (ctx_draft) {
        llama_free(ctx_draft);
        ctx_draft = nullptr;
    }
    if (model_draft) {
        llama_free_model(model_draft);
        model_draft = nullptr;
    }

    // Clear any sampling context
    for (server_slot & slot : slots) {
        if (slot.ctx_sampling != nullptr) {
            common_sampler_free(slot.ctx_sampling);
        }
        if (slot.ctx_dft) {
            llama_free(slot.ctx_dft);
        }
        common_speculative_free(slot.spec);
        llama_batch_free(slot.batch_spec);
    }

    llama_batch_free(batch);
}

bool server_context::load_model(const gpt_params & params_) {
    params_base = params_;

    llama_init_result llama_init = llama_init_from_gpt_params(params_base);

    model = llama_init.model;
    ctx = llama_init.context;
    lora_adapters = llama_init.lora_adapters;

    if (model == nullptr) {
        LOG_ERROR("unable to load model", { {"model", params_base.model} });
        return false;
    }

    n_ctx = llama_n_ctx(ctx);

    add_bos_token = llama_should_add_bos_token(model);
    has_eos_token = llama_add_eos_token(model) != 1;

    bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty();
    std::string & mmproj_path = params_base.mmproj.path;
    if (!mmproj_path.empty()) {
        mtmd_context_params mparams = mtmd_context_params_default();
        mparams.use_gpu          = params_base.mmproj_use_gpu;
        mparams.print_timings    = false;
        mparams.n_threads        = params_base.n_threads;
        mparams.flash_attn_type  = params_base.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
        mparams.verbosity        = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
        mparams.image_min_tokens = params_base.image_min_tokens;
        mparams.image_max_tokens = params_base.image_max_tokens;
        mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
        if (mctx == nullptr) {
            LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
            return false;
        }
        LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str());

        //if (params.n_cache_reuse) {
        //    params_base.n_cache_reuse = 0;
        //    SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
        //}

        if (has_draft_model) {
            LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
            return false;
        }
        if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
            params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
            SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
        }
    }
    // Load draft model for speculative decoding if specified
    if (has_draft_model) {
        if (llama_model_has_recurrent(model)) {
            LLAMA_LOG_WARN("\n=======================================================================\n");
            LLAMA_LOG_WARN(" Speculative decoding is not supported for recurrent/hybrid models\n");
            LLAMA_LOG_WARN(" --> bailing out\n");
            LLAMA_LOG_WARN("========================================================================\n\n");
            GGML_ABORT("Fatal error");
        }

        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");

        gpt_params params_dft;
        params_dft.devices      = params_base.speculative.devices;
        params_dft.model        = params_base.speculative.model;
        params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
        params_dft.rpc_servers  = params_base.rpc_servers;
        params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k;
        params_dft.cache_type_v = params_base.speculative.cache_type_v.empty() ? params_base.cache_type_v : params_base.speculative.cache_type_v;
        params_dft.flash_attn   = params_base.flash_attn;
        if (!params_base.speculative.params.empty()) {
            auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params);
            if (!gpt_params_parse(argc, argv, params_dft)) {
                gpt_params_print_usage(argc, argv, params_dft);
                free_command_line(argc, argv);
                return false;
            }
            free_command_line(argc, argv);
        }
        LOG_INFO("", { {"model", params_dft.model} });
        if (params_dft.n_ctx == 0) {
            params_dft.n_ctx = params_base.speculative.n_ctx;
        }
        params_dft.n_ctx      = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx;
        params_dft.n_parallel = 1;
        params_dft.n_batch    = params_dft.n_ctx;

        params_base.speculative.mparams_dft.path = params_dft.model;

        llama_model_params mparams_dft = common_model_params_to_llama(params_dft);

        llama_model * model_dft = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft);
        if (model_dft == nullptr) {
            LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} });
            return false;
        }

        cparams_dft = common_context_params_to_llama(params_dft);

        params_base.speculative.model_dft   = model_dft;
        params_base.speculative.cparams_dft = cparams_dft;
    }
    else if (params_base.has_mtp && llama_model_n_nextn_layer(model) == 0) {
        LOG_WARNING("WARNING: -mtp flag provided, but model has 0 NextN layers. MTP will be disabled.\n", {});
        params_base.has_mtp = false;
    }
    return true;
}
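
// The speculative.params string above is re-parsed as if it were a llama-server
// command line, so the draft model can carry its own overrides, e.g. a value such
// as "-c 4096 -ngl 99" (illustrative only; any flag accepted by gpt_params_parse works).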

void server_context::init() {
    const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;

    LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} });

    for (int i = 0; i < params_base.n_parallel; i++) {
        server_slot slot;

        slot.id = i;
        slot.ctx = ctx;
        slot.n_ctx = n_ctx_slot;
        slot.n_predict = params_base.n_predict;
        slot.mctx = mctx;
        slot.cache_tokens.has_mtmd = mctx != nullptr;
        slot.params.think_tokens = params_base.think_tokens;
        if (params_base.think_tokens.exclude) {
            SRV_WRN("Exclude reasoning tokens when selecting slot based on similarity: start: %s, end: %s\nuse `--reasoning-tokens none` to disable.\n",
                    params_base.think_tokens.begin.c_str(), params_base.think_tokens.end.c_str());
        }
        else {
            SRV_WRN("%s", "Include reasoning tokens when selecting slot based on similarity\nuse `--reasoning-tokens auto` to exclude reasoning tokens.\n");
        }
        LOG_INFO("new slot", {
            {"id_slot",    slot.id},
            {"n_ctx_slot", slot.n_ctx}
        });

        const int ga_n = params_base.grp_attn_n;
        const int ga_w = params_base.grp_attn_w;

        if (ga_n != 1) {
            GGML_ASSERT(ga_n > 0         && "ga_n must be positive");              // NOLINT
            GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n");    // NOLINT
            //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of ga_w");    // NOLINT
            //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT

            LOG_INFO("slot self-extend", {
                {"id_slot", slot.id},
                {"ga_n",    ga_n},
                {"ga_w",    ga_w}
            });
        }

        slot.ga_i = 0;
        slot.ga_n = ga_n;
        slot.ga_w = ga_w;

        slot.sparams = params_base.sparams;

        if (params_base.has_mtp) {
            if (llama_model_n_nextn_layer(model) > 0) {
                params_base.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
                params_base.pooling_type = LLAMA_POOLING_TYPE_NONE;

                params_base.speculative.cparams_dft = common_context_params_to_llama(params_base);
                params_base.speculative.cparams_dft.mtp = true;
                params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP;
                params_base.speculative.cparams_dft.embeddings = true;

                slot.has_mtp = true;
                slot.params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
                slot.params.speculative.n_min = 0;
                slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft;

                slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
                SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);

                SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
                llama_set_embeddings(ctx, true);
            }
            else {
                SRV_WRN("%s\n", "MTP enabled via flag, but model has 0 NextN layers. Disabling speculative.");
                params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
                slot.has_mtp = false;
            }
        }

        bool can_spec = true;
        if (!params_base.dry_run) {
            can_spec = common_speculative_is_compat(ctx);
        }
        if (!can_spec) {
            SRV_WRN("%s", "speculative decoding not supported by this context\n");
        }
        // try speculative decoding
        if (can_spec) {
            slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
            if (slot.spec) {
                if (mctx) {
                    SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
                    return;
                }
                SLT_INF(slot, "%s", "speculative decoding context initialized\n");
            } else {
                if (slot.has_mtp) {
                    SRV_ERR("%s", "failed to initialize MTP speculative context, aborting\n");
                    GGML_ABORT("MTP context creation failed");
                } else {
                    SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
                }
            }
        }

        slot.reset();

        slots.push_back(std::move(slot));
    }

    default_generation_settings_for_props = get_formated_generation(slots.front());
    default_generation_settings_for_props["seed"] = -1;

    // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
    // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
    {
        const int32_t n_batch = llama_n_batch(ctx);

        // only a single seq_id per token is needed
        batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
    }

    metrics.init();

    if (params_base.cache_ram_mib != 0) {
        if (params_base.cache_ram_mib < 0) {
            LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit");
        }
        else {
            LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
        }
        LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n");
        // only apply the RAM size limit; no token limit for now
        prompt_cache = std::make_unique<server_prompt_cache>(ctx, params_base.cache_ram_mib, 0);
    }
    else {
        LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
    }

    // populate chat template params
    {
        common_chat_templates_ptr chat_templates;

        try {
            chat_templates = common_chat_templates_init(model, params_base.chat_template);

            LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
                common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
        }
        catch (const std::exception & e) {
            SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
            SRV_ERR("%s: please consider enabling jinja via --jinja, or use a custom chat template via --chat-template\n", __func__);
            SRV_ERR("%s: for example: --chat-template chatml\n", __func__);
        }

        // thinking is enabled if:
        // 1. it's not explicitly disabled via --reasoning off
        // 2. the chat template supports it
        const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
        const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;

        chat_params = {
            /* use_jinja            */ params_base.use_jinja,
            /* prefill_assistant    */ params_base.prefill_assistant,
            /* reasoning_format     */ params_base.reasoning_format,
            /* chat_template_kwargs */ params_base.default_template_kwargs,
            /* tmpls                */ std::move(chat_templates),
            /* allow_image          */ mctx ? mtmd_support_vision(mctx) : false,
            /* allow_audio          */ mctx ? mtmd_support_audio(mctx) : false,
            /* enable_thinking      */ enable_thinking,
            /* parallel_tool_calls  */ params_base.parallel_tool_calls,
            /* reasoning_budget     */ params_base.reasoning_budget,
            /* reasoning_budget_msg */ params_base.reasoning_budget_message,
            /* force_pure_content   */ params_base.force_pure_content_parser
            // /* media_path        */ params_base.media_path,
        };
    }
}

void server_slot::prompt_save(server_prompt_cache & prompt_cache) const {
    assert(server_cached_prompt.data.size() == 0);

    const size_t cur_size = llama_state_seq_get_size(ctx, id, 0);

    LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n",
            (int) server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0));

    auto * cur = prompt_cache.alloc(server_cached_prompt, cur_size);
    if (cur == nullptr) {
        return;
    }

    llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
}

void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
    bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id);
    if (!res) {
        LLAMA_LOG_INFO("failed to load prompt from cache\n");
    }
}

void server_slot::reset() {
    n_prompt_tokens = 0;
    last_gentxt_size = 0;
    generated_text = "";
    truncated = false;
    stopped_eos = false;
    stopped_word = false;
    stopped_limit = false;
    stopping_word = "";
    n_past = 0;
    n_past_prompt = 0;
    n_sent_text = 0;
    drafted.clear();
    i_batch_dft.clear();
    n_sent_token_probs = 0;
    infill = false;
    ga_i = 0;
    n_past_se = 0;
    chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

    logit_bias.clear();
    token_buffer.clear();
    rewind_count = 0;
    n_buffer = 0;
    rewind_status = false;

    generated_token_probs.clear();
    checkpoint_pos = 0;
    image_just_processed = false;
    do_checkpoint = false;

    positional_bans.clear();
    ban_phrases.clear();
    ban_regex.clear();
    ban_regex_ci.clear();

    allow_ruless.clear();
    allow_pieces.clear();
    allow_kws.clear();
    allow_kw_delay = 0;
    allow_idx = 0;

    // Reset speculative decoding stats
    n_draft_total = 0;
    n_draft_accepted = 0;
    chat_msg = {};
    json_schema = json();
    generated_tool_call_ids.clear();

    anthropic_thinking_block_started = false;
    anthropic_text_block_started = false;

    oai_resp_thinking_block_started = false;
    oai_resp_text_block_started = false;
    oai_resp_id.clear();
    oai_resp_reasoning_id.clear();
    oai_resp_message_id.clear();
    oai_resp_fc_id.clear();

    task.reset();
}

bool server_slot::need_embd() const {
    return embedding || has_mtp;
}

bool server_slot::has_budget(gpt_params & global_params) {
    if (params.n_predict == -1 && global_params.n_predict == -1) {
        return true; // limitless
    }

    n_remaining = -1;

    if (params.n_predict != -1) {
        n_remaining = params.n_predict - n_decoded;
    }
    else if (global_params.n_predict != -1) {
        n_remaining = global_params.n_predict - n_decoded;
    }

    return n_remaining > 0; // true while there is budget left
}
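
// Illustrative example for has_budget() above: with a per-request n_predict of 128
// and 100 tokens already decoded, n_remaining becomes 28 and generation continues;
// with both the request and the global n_predict at -1 the budget is unlimited.
// (The numbers are hypothetical.)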

bool server_slot::available() const {
    return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
}

bool server_slot::is_processing() const {
    return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
}

void server_slot::add_token_string(const completion_token_output & token) {
    if (command == SLOT_COMMAND_RELEASE) {
        return;
    }
    generated_token_probs.push_back(token);
}

bool server_slot::can_speculate() const {
    return (!!spec || has_mtp);
}

int server_slot::get_n_draft_max() const {
    if (!can_speculate()) {
        return 0;
    }

    // determine the max draft that fits the current slot state
    int n_draft_max = params.speculative.n_max;

    // note: slot.prompt is not yet expanded with the `id` token sampled above
    // also, need to leave space for 1 extra token to allow context shifts
    n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);

    if (n_remaining > 0) {
        n_draft_max = std::min(n_draft_max, n_remaining - 1);
    }

    SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);

    if (n_draft_max < params.speculative.n_min) {
        SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, params.speculative.n_min);
        n_draft_max = 0;
    }
    return n_draft_max;
}
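
// Illustrative example for get_n_draft_max() above: with speculative.n_max = 16,
// n_ctx = 4096 and n_past = 4000 the cap is min(16, 4096 - 4000 - 2) = 16; if only
// 5 tokens remain in the generation budget it shrinks to min(16, 5 - 1) = 4, and
// anything below speculative.n_min disables drafting for this step.
// (The numbers are hypothetical.)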

void server_slot::release() {
    if (state == SLOT_STATE_PROCESSING) {
        t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
        command = SLOT_COMMAND_RELEASE;
        state = SLOT_STATE_IDLE;
        task.reset();
        llama_decode_reset();
    }
}

json server_slot::get_formated_timings() const {
    return json{
        {"prompt_n",               n_prompt_tokens_processed},
        {"prompt_ms",              t_prompt_processing},
        {"prompt_per_token_ms",    t_prompt_processing / n_prompt_tokens_processed},
        {"prompt_per_second",      1e3 / t_prompt_processing * n_prompt_tokens_processed},

        {"predicted_n",            n_decoded},
        {"predicted_ms",           t_token_generation},
        {"predicted_per_token_ms", t_token_generation / n_decoded},
        {"predicted_per_second",   1e3 / t_token_generation * n_decoded},

        {"n_ctx",  n_ctx},
        {"n_past", n_past},
    };
}

result_timings server_slot::get_timings() const {
    result_timings timings;
    timings.prompt_n = n_prompt_tokens_processed;
    timings.prompt_ms = t_prompt_processing;
    timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
    timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

    timings.predicted_n = n_decoded;
    timings.predicted_ms = t_token_generation;
    timings.predicted_per_token_ms = t_token_generation / n_decoded;
    timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;

    timings.n_ctx = n_ctx;
    timings.n_past = n_past;

    // Add speculative metrics
    if (n_draft_total > 0) {
        timings.draft_n = n_draft_total;
        timings.draft_n_accepted = n_draft_accepted;
    }

    return timings;
}
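
// Illustrative example of the timing math above: a 512-token prompt processed in
// 800 ms gives 800 / 512 = 1.5625 ms per token and 1e3 / 800 * 512 = 640 tokens
// per second. (The numbers are hypothetical.)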

// Re-parse the accumulated generated_text into a chat message and compute the
// streaming diffs against the previous parse. When filter_tool_calls is set,
// each tool call's name/id header is emitted exactly once and later argument-only
// deltas are stripped of the name/id so clients do not receive them twice.
const common_chat_msg & server_slot::update_chat_msg(bool is_partial, std::vector<common_chat_msg_diff> & diffs,
        bool filter_tool_calls) {
    auto msg_prv_copy = chat_msg;
    auto new_msg = common_chat_parse(
        generated_text,
        /* is_partial= */ stop != STOP_TYPE_EOS,
        params.chat_parser_params);
    if (!new_msg.empty()) {
        //new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
        new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
        chat_msg = new_msg;
        auto all_diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, chat_msg);

        if (!filter_tool_calls) {
            diffs = std::move(all_diffs);
        } else {
            for (auto & d : all_diffs) {
                // If this is a new type of delta, flush all currently pending tool call names
                for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) {
                    if (sent_tool_call_names.count(i) || chat_msg.tool_calls[i].name.empty()) {
                        continue;
                    }
                    if (d.tool_call_index != i || !d.tool_call_delta.arguments.empty()) {
                        common_chat_msg_diff header;
                        header.tool_call_index = i;
                        header.tool_call_delta.id = chat_msg.tool_calls[i].id;
                        header.tool_call_delta.name = chat_msg.tool_calls[i].name;
                        diffs.push_back(std::move(header));
                        sent_tool_call_names.insert(i);
                    }
                }

                if (d.tool_call_index == std::string::npos) {
                    diffs.push_back(std::move(d));
                } else {
                    size_t i = d.tool_call_index;
                    if (sent_tool_call_names.count(i)) {
                        if (!d.tool_call_delta.arguments.empty()) {
                            d.tool_call_delta.name = "";
                            d.tool_call_delta.id = "";
                            diffs.push_back(std::move(d));
                        }
                    } else {
                        // Not sent yet.
                        if (!d.tool_call_delta.arguments.empty() || !is_partial) {
                            d.tool_call_delta.name = chat_msg.tool_calls[i].name;
                            d.tool_call_delta.id = chat_msg.tool_calls[i].id;
                            diffs.push_back(std::move(d));
                            sent_tool_call_names.insert(i);
                        } else {
                            // Suppress
                        }
                    }
                }
            }
            // Final check at EOF
            if (!is_partial) {
                for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) {
                    if (!sent_tool_call_names.count(i) && !chat_msg.tool_calls[i].name.empty()) {
                        common_chat_msg_diff header;
                        header.tool_call_index = i;
                        header.tool_call_delta.id = chat_msg.tool_calls[i].id;
                        header.tool_call_delta.name = chat_msg.tool_calls[i].name;
                        diffs.push_back(std::move(header));
                        sent_tool_call_names.insert(i);
                    }
                }
            }
        }
    }
    return chat_msg;
}

size_t server_slot::find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
    size_t stop_pos = std::string::npos;

    for (const std::string & word : params.antiprompt) {
        size_t pos;

        if (is_full_stop) {
            const size_t tmp = word.size() + last_token_size;
            const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;

            pos = text.find(word, from_pos);
        }
        else {
            pos = string_find_partial_stop(text, word);
        }

        if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
            if (is_full_stop) {
                stopped_word = true;
                stopping_word = word;
                has_next_token = false;
            }
            stop_pos = pos;
        }
    }

    return stop_pos;
}

void server_slot::print_timings() const {
    char buffer[512];
    double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
    double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;

    double t_gen = t_token_generation / n_decoded;
    double n_gen_second = 1e3 / t_token_generation * n_decoded;

    SLT_INF(*this,
            "\n"
            "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
            "       eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
            "      total time = %10.2f ms / %5d tokens\n",
            t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
            t_token_generation, n_decoded, t_gen, n_gen_second,
            t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);

    if (n_draft_total > 0) {
        const float draft_ratio = (float) n_draft_accepted / n_draft_total;
        SLT_CNT(*this,
                "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
                draft_ratio, n_draft_accepted, n_draft_total
        );
    }
    common_speculative_print_stats(spec, n_gen_second, n_decoded, n_past,
            const_cast<common_params_speculative *>(&params.speculative));
}

void server_metrics::init() {
    t_start = ggml_time_us();
}

void server_metrics::on_prompt_eval(const server_slot & slot) {
    n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
    n_prompt_tokens_processed       += slot.n_prompt_tokens_processed;
    t_prompt_processing             += slot.t_prompt_processing;
    t_prompt_processing_total       += slot.t_prompt_processing;
}

void server_metrics::on_prediction(const server_slot & slot) {
    n_tokens_predicted_total  += slot.n_decoded;
    n_tokens_predicted        += slot.n_decoded;
    t_tokens_generation       += slot.t_token_generation;
    t_tokens_generation_total += slot.t_token_generation;
}

void server_metrics::reset_bucket() {
    n_prompt_tokens_processed = 0;
    t_prompt_processing = 0;
    n_tokens_predicted = 0;
    t_tokens_generation = 0;
}

std::vector<llama_token> server_context::tokenize(const json & json_prompt, bool add_special) const {
    // TODO: currently, we tokenize using special tokens by default
    //       this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
    //       but it's better compared to completely ignoring ChatML and other chat templates
    const bool TMP_FORCE_SPECIAL = true;

    // If `add_special` is true, BOS is only added when json_prompt is a string,
    // or for the first string element of the json_prompt array.
    std::vector<llama_token> prompt_tokens;

    if (json_prompt.is_array()) {
        bool first = true;
        for (const auto & p : json_prompt) {
            if (p.is_string()) {
                auto s = p.template get<std::string>();

                std::vector<llama_token> p;
                if (first) {
                    p = ::common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
                    first = false;
                }
                else {
                    p = ::common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
                }

                prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
            }
            else {
                if (first) {
                    first = false;
                }

                prompt_tokens.push_back(p.template get<llama_token>());
            }
        }
    }
    else {
        auto s = json_prompt.template get<std::string>();
        prompt_tokens = ::common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
    }

    return prompt_tokens;
}
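
// tokenize() above accepts either a plain string or a JSON array that mixes strings
// and raw token ids, e.g. (illustrative values):
//   "prompt": "Hello world"
//   "prompt": ["<|user|>", 198, "Hello world"]
// Only the first string element can receive the BOS token (controlled by add_special).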

server_slot * server_context::get_slot_by_id(int id) {
    for (server_slot & slot : slots) {
        if (slot.id == id) {
            return &slot;
        }
    }

    return nullptr;
}

float server_context::calculate_slot_f_keep(const server_slot & slot, llama_context * ctx, const server_tokens & a, const server_tokens & b) {
    float f_keep = 0.0f;
    if (!a.empty()) {
        if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, slot.params.n_keep + add_bos_token, slot.n_discarded_prompt);
        }
        else {
            f_keep = a.get_cached_tokens_similarity(slot.ctx, b, 0, 0);
        }
    }
    return f_keep;
}

std::pair<common_prefix, float> server_context::calculate_slot_similarity(const server_slot & slot, llama_context * ctx, const server_tokens & a, const server_tokens & b) {
    std::pair<common_prefix, float> sim;
    // length of the Longest Common Prefix between the current slot's prompt and the input prompt
    common_prefix lcp_len = a.get_common_prefix(slot.ctx, b);
    // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length
    float sim_cur = a.get_tokens_similarity(slot.ctx, b, 0, 0);
    // handle context shift
    if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && b.size() >= slot.n_ctx) {
        float sim_cur_ctx_shift = a.get_tokens_similarity(slot.ctx, b, slot.n_kept_prompt, slot.n_discarded_prompt);
        if (sim_cur_ctx_shift > sim_cur) {
            sim_cur = sim_cur_ctx_shift;
        }
    }
    sim.first = lcp_len;
    sim.second = sim_cur;
    return sim;
}
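
// Illustrative example for calculate_slot_similarity() above: if a slot's cached
// prompt shares a 900-token common prefix with a 1000-token incoming prompt, the
// similarity is roughly 0.9; when the slot has discarded tokens due to a context
// shift, a shifted variant is also computed and the larger of the two is kept.
// (The exact weighting lives in server_tokens::get_tokens_similarity; the numbers
// here are hypothetical.)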

void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) {
    slot.server_cached_prompt.tokens = tokens.clone(); // copy cache tokens
    slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
    slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
    slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
}

server_slot * server_context::get_available_slot(const server_task & task) {
    server_slot * ret = nullptr;
    bool update_cache = false;

    // find the slot that has at least n% prompt similarity
    if (ret == nullptr && slot_prompt_similarity != 0.0f) {
        int max_lcp_len = 0;
        float sim_best = 0;

        for (server_slot & slot : slots) {
            // skip the slot if it is not available
            if (!slot.available()) {
                continue;
            }
            auto & cache_tokens = slot.cache_tokens;
            // skip the slot if it does not contain a cached prompt
            if (cache_tokens.empty()) {
                continue;
            }
            std::pair<common_prefix, float> sim;
            if (slot.params.think_tokens.exclude) {
                server_tokens cache_tokens_exclude_think = slot.cache_tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
                server_tokens prompt_tokens_exclude_think = task.tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens);
                sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think);
            }
            else {
                sim = calculate_slot_similarity(slot, ctx, cache_tokens, task.tokens);
            }
            common_prefix lcp_len = sim.first;
            float sim_cur = sim.second;

            // select the current slot if the criteria match
            if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
                sim_best = sim_cur;
                max_lcp_len = lcp_len.first;
                ret = &slot;
            }
        }
        if (ret != nullptr) {
            LOG_VERBOSE("selected slot by lcp similarity", {
                {"id_slot",     ret->id},
                {"max_lcp_len", max_lcp_len},
                {"similarity",  sim_best},
            });
        }
    }

    // find the slot that has been least recently used
    if (ret == nullptr) {
        int64_t t_last = ggml_time_us();
        for (server_slot & slot : slots) {
            // skip the slot if it is not available
            if (!slot.available()) {
                continue;
            }
            // select the current slot if the criteria match
            if (slot.t_last_used < t_last) {
                t_last = slot.t_last_used;
                ret = &slot;
            }
        }

        if (ret != nullptr) {
            LOG_VERBOSE("selected slot by lru", {
                {"id_slot", ret->id},
                {"t_last",  t_last},
            });
        }
    }
    if (ret) {
        auto & tokens = ret->cache_tokens;
        float f_keep = 0;
        size_t cache_token_size = tokens.size();
        if (!tokens.empty()) {
            if (ret->params.think_tokens.exclude) {
                server_tokens cache_exclude_think = tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);
                server_tokens prompt_exclude_think = task.tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens);

                cache_token_size = cache_exclude_think.size();
                f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think);
            }
            else {
                f_keep = calculate_slot_f_keep(*ret, ret->ctx, tokens, task.tokens);
            }
            // if we are about to lose a large portion of the existing context - save it in the prompt cache
            if (f_keep < cache_ram_similarity) {
                update_cache = true;
            }
        }

        update_cache = update_cache && prompt_cache;
        // cache prompts only for completion tasks
        update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;

        // don't update the cache if the slot's context is below cache_ram_n_min
        update_cache = update_cache && cache_token_size >= cache_ram_n_min;

        LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n",
                (int) tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity);
        if (update_cache) {
            const int64_t t_start = ggml_time_us();
            LLAMA_LOG_INFO("updating prompt cache\n");
            // copy cache tokens
            copy_data_to_cached_prompt(tokens, *ret);

            ret->prompt_save(*prompt_cache);
            LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
        }
        // has prompts saved earlier to load
        if (prompt_cache && !prompt_cache->states.empty()) {
            const int64_t t_start = ggml_time_us();
            copy_data_to_cached_prompt(tokens, *ret);

            ret->prompt_load(*prompt_cache, task.tokens);
            prompt_cache->update();

            ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens
            ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt;
            ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt;

            LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
        }
    }
    return ret;
}

int32_t server_context::populate_vocab_pieces() {
    const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
    if (vocab_pieces.size() == n_vocab) {
        return n_vocab;
    }
    vocab_pieces.clear();
    vocab_pieces.reserve(n_vocab);
    for (int32_t id = 0; id < n_vocab; ++id) {
        vocab_pieces.push_back(common_token_to_piece(ctx, id, true));
    }
    return n_vocab;
}
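
// launch_slot_with_task() below reads per-request overrides from task.data.
// A minimal sketch of such a request body (illustrative; only keys that are
// actually parsed below matter, everything else falls back to the server defaults):
//   {
//     "stream": true,
//     "cache_prompt": true,
//     "n_predict": 256,
//     "temperature": 0.7,
//     "top_k": 40,
//     "top_p": 0.95,
//     "min_p": 0.05,
//     "repeat_penalty": 1.1,
//     "json_schema": { "type": "object" }
//   }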
|
|
|
|
bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) {
|
|
slot_params defaults;
|
|
defaults.speculative = params_base.speculative;
|
|
|
|
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
|
common_params_sampling default_sparams = params_base.sparams;
|
|
auto& data = task.data;
|
|
const llama_vocab* vocab = llama_model_get_vocab(model);
|
|
if (data.count("__oaicompat") != 0) {
|
|
slot.oaicompat = true;
|
|
slot.oaicompat_model = task.params.oaicompat_model;
|
|
}
|
|
else {
|
|
slot.oaicompat = false;
|
|
slot.oaicompat_model = "";
|
|
}
|
|
slot.params.oaicompat = task.params.oaicompat;
|
|
slot.params.oaicompat_cmpl_id =task.params.oaicompat_cmpl_id;
|
|
|
|
slot.oai_resp_thinking_block_started = false;
|
|
slot.oai_resp_text_block_started = false;
|
|
slot.oai_resp_id = "resp_" + random_string();
|
|
slot.oai_resp_reasoning_id = "rs_" + random_string();
|
|
slot.oai_resp_message_id = "msg_" + random_string();
|
|
slot.oai_resp_fc_id.clear();
|
|
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
|
|
slot.params.stream = json_value(data, "stream", false);
|
|
auto stream_opt = json_value(data, "stream_options", json::object());
|
|
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
|
|
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
|
|
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
|
slot.saturate_predict = json_value(data, "saturate_predict", false);
|
|
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
|
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
|
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
|
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
|
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
|
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
|
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
|
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
|
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
|
|
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
|
|
slot.sparams.top_n_sigma = json_value(data, "top_n_sigma", default_sparams.top_n_sigma);
|
|
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
|
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
|
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
|
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
|
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
|
|
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
|
|
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
|
|
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
|
|
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
|
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
|
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
|
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
|
|
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
|
|
slot.sparams.adaptive_updt_w_cur = json_value(data, "adaptive_updt_w_cur", default_sparams.adaptive_updt_w_cur);
|
|
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
|
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
|
slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
|
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
|
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
|
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
|
|
|
slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
|
|
|
// speculative decoding parameters
|
|
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params_base.speculative.n_max);
|
|
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params_base.speculative.n_min);
|
|
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params_base.speculative.p_min);
|
|
|
|
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
|
|
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
|
|
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
|
|
|
|
slot.params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
|
|
|
// Clamp speculative parameters
|
|
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
|
|
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
|
|
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
|
|
|
|
slot.params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
|
|
slot.params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
|
|
slot.params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_min_hits", defaults.speculative.ngram_min_hits);
|
|
|
|
// clamp the n-gram lookup parameters to the [1, 1024] range
slot.params.speculative.ngram_size_n = std::min(std::max(1, (int)slot.params.speculative.ngram_size_n), 1024);
|
|
slot.params.speculative.ngram_size_m = std::min(std::max(1, (int)slot.params.speculative.ngram_size_m), 1024);
|
|
slot.params.speculative.ngram_min_hits = std::min(std::max(1, (int)slot.params.speculative.ngram_min_hits), 1024);
|
|
|
|
|
|
if (slot.sparams.penalty_last_n < -1) {
|
|
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
|
}
|
|
|
|
if (slot.sparams.dry_penalty_last_n < -1) {
|
|
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
|
}
|
|
|
|
if (slot.sparams.penalty_last_n == -1) {
|
|
// note: should be the slot's context and not the full context, but it's ok
|
|
slot.sparams.penalty_last_n = llama_n_ctx(ctx);
|
|
}
|
|
|
|
if (slot.sparams.dry_penalty_last_n == -1) {
|
|
slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
|
|
|
|
}
|
|
if (slot.sparams.dry_base < 1.0f)
|
|
{
|
|
slot.sparams.dry_base = default_sparams.dry_base;
|
|
}
|
|
|
|
// sequence breakers for DRY
|
|
{
|
|
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
|
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
|
|
|
if (data.contains("dry_sequence_breakers")) {
|
|
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
|
if (slot.sparams.dry_sequence_breakers.empty()) {
|
|
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
// process "json_schema" and "grammar"
|
|
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
try {
|
|
auto schema = json_value(data, "json_schema", json::object());
|
|
LLAMA_LOG_DEBUG("JSON schema: %s\n", schema.dump(2).c_str());
|
|
std::string grammar_str = json_schema_to_grammar(schema);
|
|
SRV_DBG("Converted grammar: %s\n", grammar_str.c_str());
|
|
slot.sparams.grammar = { COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str) };
|
|
}
|
|
catch (const std::exception& e) {
|
|
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
|
}
|
|
}
|
|
else {
|
|
slot.sparams.grammar = default_sparams.grammar;
|
|
std::string grammar_str = json_value(data, "grammar", std::string());
|
|
if (!grammar_str.empty()) {
|
|
// grammar_type key is set by the server when converting chat template grammars
|
|
std::string grammar_type = json_value(data, "grammar_type", std::string());
|
|
if (grammar_type == "tool_calls") {
|
|
slot.sparams.grammar = { COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str) };
|
|
} else {
|
|
// explicit grammar from the user (API field "grammar")
|
|
slot.sparams.grammar = { COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str) };
|
|
}
|
|
LLAMA_LOG_DEBUG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(slot.sparams.grammar).c_str());
|
|
}
|
|
slot.sparams.grammar_lazy = json_value(data, "grammar_lazy", default_sparams.grammar_lazy);
|
|
LLAMA_LOG_DEBUG("Grammar lazy: %s\n", slot.sparams.grammar_lazy ? "true" : "false");
|
|
}
|
|
|
|
if (slot.params.cache_prompt && slot.ga_n != 1) {
|
|
LOG_WARNING("cache_prompt is not supported with group-attention", {});
|
|
slot.params.cache_prompt = false;
|
|
}
|
|
|
|
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
|
// Might be better to reject the request with a 400 ?
|
|
LOG_WARNING("Max tokens to predict exceeds server configuration", {
|
|
{"params.n_predict", slot.params.n_predict},
|
|
{"slot.n_predict", slot.n_predict},
|
|
});
|
|
slot.params.n_predict = slot.n_predict;
|
|
}
|
|
|
|
// infill
|
|
slot.params.input_prefix = json_value(data, "input_prefix", defaults.input_prefix);
|
|
slot.params.input_suffix = json_value(data, "input_suffix", defaults.input_suffix);
|
|
|
|
// get prompt
|
|
if (!task.infill) {
|
|
slot.prompt_tokens = std::move(task.tokens);
|
|
|
|
const auto & prompt = data.find("prompt");
|
|
if (prompt != data.end()) {
|
|
if (prompt->is_string() ||
|
|
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
|
slot.prompt = *prompt;
|
|
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
|
|
slot.prompt = *prompt;
|
|
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
|
slot.prompt = prompt->at(0);
|
|
}
|
|
}
|
|
}
|
|
|
|
// penalize user-provided tokens
|
|
{
|
|
slot.sparams.penalty_prompt_tokens.clear();
|
|
slot.sparams.use_penalty_prompt_tokens = false;
|
|
|
|
const auto& penalty_prompt = data.find("penalty_prompt");
|
|
|
|
if (penalty_prompt != data.end()) {
|
|
if (penalty_prompt->is_string()) {
|
|
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
|
|
slot.sparams.penalty_prompt_tokens = common_tokenize(model, penalty_prompt_string, false);
|
|
|
|
if (slot.params.n_predict > 0) {
|
|
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
|
|
}
|
|
slot.sparams.use_penalty_prompt_tokens = true;
|
|
|
|
LOG_VERBOSE("penalty_prompt_tokens", {
|
|
{"id_slot", slot.id},
|
|
{"tokens", slot.sparams.penalty_prompt_tokens},
|
|
});
|
|
}
|
|
else if (penalty_prompt->is_array()) {
|
|
const auto n_tokens = penalty_prompt->size();
|
|
slot.sparams.penalty_prompt_tokens.clear();
|
|
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
|
|
|
|
const int n_vocab = llama_n_vocab(model);
|
|
for (const auto& penalty_token : *penalty_prompt) {
|
|
if (penalty_token.is_number_integer()) {
|
|
const auto tok = penalty_token.get<llama_token>();
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
slot.sparams.penalty_prompt_tokens.push_back(tok);
|
|
}
|
|
}
|
|
}
|
|
slot.sparams.use_penalty_prompt_tokens = true;
|
|
|
|
LOG_VERBOSE("penalty_prompt_tokens", {
|
|
{"id_slot", slot.id},
|
|
{"tokens", slot.sparams.penalty_prompt_tokens},
|
|
});
|
|
}
|
|
}
|
|
}
|
|
{
|
|
auto it = data.find("chat_format");
|
|
if (it != data.end()) {
|
|
slot.params.chat_parser_params.format = static_cast<common_chat_format>(it->get<int>());
|
|
LLAMA_LOG_DEBUG("Chat format: %s\n", common_chat_format_name(slot.params.chat_parser_params.format));
|
|
}
|
|
else {
|
|
slot.params.chat_parser_params.format = defaults.chat_parser_params.format;
|
|
}
|
|
common_reasoning_format reasoning_format = params_base.reasoning_format;
|
|
if (data.contains("reasoning_format")) {
|
|
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
|
|
}
|
|
slot.params.chat_parser_params.reasoning_format = reasoning_format;
|
|
slot.params.chat_parser_params.reasoning_in_content = slot.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
|
slot.params.chat_parser_params.generation_prompt = json_value(data, "generation_prompt", std::string());
|
|
slot.sparams.generation_prompt = slot.params.chat_parser_params.generation_prompt;
|
|
LLAMA_LOG_DEBUG("Generation prompt: '%s'\n", slot.params.chat_parser_params.generation_prompt.c_str());
|
|
slot.params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
|
if (data.contains("chat_parser")) {
|
|
slot.params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
|
|
}
|
|
}
|
|
{
|
|
|
|
const auto preserved_tokens = data.find("preserved_tokens");
|
|
if (preserved_tokens != data.end()) {
|
|
slot.sparams.preserved_tokens.clear();
|
|
for (const auto& t : *preserved_tokens) {
|
|
auto ids = common_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
|
if (ids.size() == 1) {
|
|
LOG("Preserved token: %d\n", ids[0]);
|
|
slot.sparams.preserved_tokens.insert(ids[0]);
|
|
}
|
|
else {
|
|
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
|
LOG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
|
|
}
|
|
}
|
|
}
|
|
const auto grammar_triggers = data.find("grammar_triggers");
|
|
if (grammar_triggers != data.end()) {
|
|
slot.sparams.grammar_triggers.clear();
|
|
for (const auto& t : *grammar_triggers) {
|
|
server_grammar_trigger ct(t);
|
|
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
|
const auto& word = ct.value.value;
|
|
auto ids = common_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
|
|
if (ids.size() == 1) {
|
|
auto token = ids[0];
|
|
if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) {
|
|
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
|
|
}
|
|
LOG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
|
|
common_grammar_trigger trigger;
|
|
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
|
trigger.value = word;
|
|
trigger.token = token;
|
|
slot.sparams.grammar_triggers.push_back(std::move(trigger));
|
|
}
|
|
else {
|
|
LOG("Grammar trigger word: `%s`\n", word.c_str());
|
|
slot.sparams.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word });
|
|
}
|
|
}
|
|
else {
|
|
//slot.sparams.grammar_triggers.push_back(ct);
|
|
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
|
|
LLAMA_LOG_DEBUG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
|
|
}
|
|
else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
|
|
LLAMA_LOG_DEBUG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
|
|
}
|
|
else {
|
|
throw std::runtime_error("Unknown grammar trigger type");
|
|
}
|
|
slot.sparams.grammar_triggers.emplace_back(std::move(ct.value));
|
|
}
|
|
}
|
|
}
|
|
|
|
if (slot.sparams.grammar_lazy && slot.sparams.grammar_triggers.empty()) {
|
|
throw std::runtime_error("Error: no triggers set for lazy grammar!");
|
|
}
|
|
}
|
|
|
|
// Parse reasoning budget sampler parameters
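// the start/end tags delimit the reasoning section; the "forced" sequence tokenized below is the
// optional message followed by the end tag, presumably injected by the reasoning-budget sampler
// once the budget is exhausted (the default of -1 suggests negative values disable the budget --
// this is an assumption based on the defaults, not verified here)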
|
|
{
|
|
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t)-1);
|
|
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
|
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
|
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
|
slot.sparams.reasoning_budget_tokens = budget;
|
|
|
|
if (!start_tag.empty()) {
|
|
slot.sparams.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
|
}
|
|
if (!end_tag.empty()) {
|
|
slot.sparams.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
|
slot.sparams.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
|
|
|
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
|
budget, slot.sparams.generation_prompt.c_str(),
|
|
slot.sparams.reasoning_budget_start.size(),
|
|
slot.sparams.reasoning_budget_end.size(),
|
|
slot.sparams.reasoning_budget_forced.size());
|
|
}
|
|
}
|
|
|
|
{ // apply logit bias
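// two request formats are accepted: an array of [token-or-string, bias] pairs, or an object
// mapping a token id (as a string) or a text piece to a bias; a boolean false bans the entry
// by assigning a bias of -INFINITY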
|
|
const auto& logit_bias = data.find("logit_bias");
|
|
if (logit_bias != data.end() && (logit_bias->is_object() || logit_bias->is_array())) {
|
|
slot.sparams.logit_bias.clear(); // only clear if user sets it
|
|
}
|
|
if (logit_bias != data.end() && logit_bias->is_array()) {
|
|
const int n_vocab = llama_n_vocab(model);
|
|
for (const auto& el : *logit_bias) {
|
|
// TODO: we may want to throw errors here, in case "el" is incorrect
|
|
if (el.is_array() && el.size() == 2) {
|
|
float bias;
|
|
if (el[1].is_number()) {
|
|
bias = el[1].get<float>();
|
|
}
|
|
else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
|
bias = -INFINITY;
|
|
}
|
|
else {
|
|
continue;
|
|
}
|
|
|
|
if (el[0].is_number_integer()) {
|
|
llama_token tok = el[0].get<llama_token>();
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
slot.sparams.logit_bias[tok] = bias;
|
|
}
|
|
}
|
|
else if (el[0].is_string()) {
|
|
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
|
|
for (auto tok : toks) {
|
|
slot.sparams.logit_bias[tok] = bias;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else if (logit_bias != data.end() && logit_bias->is_object()) {
|
|
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
for (const auto& el : logit_bias->items()) {
|
|
float bias;
|
|
const auto& key = el.key();
|
|
const auto& value = el.value();
|
|
if (value.is_number()) {
|
|
bias = value.get<float>();
|
|
}
|
|
else if (value.is_boolean() && !value.get<bool>()) {
|
|
bias = -INFINITY;
|
|
}
|
|
else {
|
|
continue;
|
|
}
|
|
|
|
char* end;
|
|
llama_token tok = strtol(key.c_str(), &end, 10);
|
|
if (*end == 0) {
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
slot.sparams.logit_bias[tok] = bias;
|
|
}
|
|
}
|
|
else {
|
|
auto toks = common_tokenize(model, key, false);
|
|
for (auto tok : toks) {
|
|
slot.sparams.logit_bias[tok] = bias;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if (json_value(data, "ignore_eos", false) && has_eos_token) {
|
|
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
|
}
|
|
|
|
}
|
|
|
|
{
|
|
// ban string
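// banned strings are lower-cased here; n_buffer tracks the length of the longest banned
// string/regex and is finalized below (banbuffer_size if provided, otherwise longest entry + 1);
// the actual matching against generated text happens later, during generation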
|
|
int32_t banbuffer_size = json_value(data, "banbuffer_size", 0);
|
|
slot.n_buffer = 0; // Ensure buffer calculation starts fresh for this slot
|
|
slot.rewind_count_max = json_value(data, "rewind_count_max", -1);
|
|
|
|
const auto& banned_strings = data.find("banned_strings");
|
|
if (banned_strings != data.end() && banned_strings->is_array()) {
|
|
slot.ban_phrases.clear();
|
|
for (const auto& val : data["banned_strings"]) {
|
|
if (val.is_string()) {
|
|
std::string s = val.get<std::string>();
|
|
if (!s.empty()) {
|
|
s = string_lower(s);
|
|
// Use string length instead of token count
|
|
if (s.length() > slot.n_buffer) {
|
|
slot.n_buffer = s.length();
|
|
}
|
|
slot.ban_phrases.push_back(s);
|
|
}
|
|
}
|
|
}
|
|
std::sort(slot.ban_phrases.begin(), slot.ban_phrases.end(), [](const std::string& a, const std::string& b) {
|
|
return a.length() > b.length();
|
|
});
|
|
} else if (params_base.ban_phrases.size() > 0) {
|
|
if (params_base.n_buffer == 0) {
|
|
slot.ban_phrases.clear();
|
|
std::sort(params_base.ban_phrases.begin(), params_base.ban_phrases.end(), [](const std::string & a, const std::string & b) {
|
|
return a.length() > b.length();
|
|
});
|
|
for (auto & val : params_base.ban_phrases) {
|
|
if (!val.empty()) {
|
|
val = string_lower(val);
|
|
// Use string length instead of token count
|
|
if (val.length() > slot.n_buffer) {
|
|
slot.n_buffer = val.length();
|
|
}
|
|
slot.ban_phrases.push_back(val);
|
|
}
|
|
}
|
|
params_base.n_buffer = slot.n_buffer + 1; // buffer is longest string + 1
|
|
} else {
|
|
slot.ban_phrases = params_base.ban_phrases;
|
|
slot.n_buffer = params_base.n_buffer;
|
|
}
|
|
}
|
|
|
|
// ban regex
|
|
slot.ban_regex.clear();
|
|
const auto& banned_regex = data.find("banned_regex");
|
|
if (banned_regex != data.end() && banned_regex->is_array()) {
|
|
for (const auto& val : data["banned_regex"]) {
|
|
if (val.is_string()) {
|
|
std::string s = val.get<std::string>();
|
|
if (!s.empty()) {
|
|
try {
|
|
std::regex re(s);
|
|
slot.ban_regex.push_back(s);
|
|
if (s.length() > slot.n_buffer) {
|
|
slot.n_buffer = s.length();
|
|
}
|
|
} catch (const std::regex_error& e) {
|
|
send_error(task, "Invalid regex in banned_regex: " + s, ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ban regex case insensitive
|
|
slot.ban_regex_ci.clear();
|
|
const auto& banned_regex_ci = data.find("banned_regex_case_insensitive");
|
|
if (banned_regex_ci != data.end() && banned_regex_ci->is_array()) {
|
|
for (const auto& val : data["banned_regex_case_insensitive"]) {
|
|
if (val.is_string()) {
|
|
std::string s = val.get<std::string>();
|
|
if (!s.empty()) {
|
|
try {
|
|
std::regex re(s, std::regex_constants::icase);
|
|
slot.ban_regex_ci.push_back(s);
|
|
if (s.length() > slot.n_buffer) {
|
|
slot.n_buffer = s.length();
|
|
}
|
|
} catch (const std::regex_error& e) {
|
|
send_error(task, "Invalid regex in banned_regex_case_insensitive: " + s, ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if (banned_regex_ci != data.end() || banned_regex != data.end() || banned_strings != data.end()) {
|
|
if (banbuffer_size > 0) {
|
|
slot.n_buffer = banbuffer_size;
|
|
} else {
|
|
slot.n_buffer = slot.n_buffer + 1; // buffer is longest string/regex + 1
|
|
}
|
|
}
|
|
slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
|
|
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
|
|
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
|
|
}
|
|
|
|
do // populate allowlist biases
|
|
{
|
|
// TODO: JSON parsing for rules and keywords
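// each entry of allow_ruless is a set of rules (roughly: codepoint range, script, bias); for
// every vocab piece each codepoint must be covered by at least one rule (common/inherited
// codepoints are skipped), otherwise the token gets a bias of -INFINITY; tokens produced from
// allow_pieces always receive the maximum bias of the active rule set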
|
|
slot.allow_ruless = params_base.allow_ruless;
|
|
if (slot.allow_ruless.size() == 0) {
|
|
slot.allow_biasess.clear();
|
|
break;
|
|
}
|
|
slot.allow_kws = params_base.allow_kws;
|
|
|
|
slot.allow_pieces = params_base.allow_pieces;
|
|
const auto& allowlist_piece_array = data.find("allowlist_piece_array");
|
|
if (allowlist_piece_array != data.end() && allowlist_piece_array->is_array()) {
|
|
slot.allow_pieces.clear();
|
|
for (const auto& piece: *allowlist_piece_array) {
|
|
if (piece.is_string()) {
|
|
slot.allow_pieces.push_back(piece.get<std::string>());
|
|
}
|
|
}
|
|
}
|
|
|
|
slot.allow_kw_delay = json_value(data, "allowlist_keyword_delay", params_base.allow_kw_delay);
|
|
// end of allowlist criteria update
|
|
|
|
const int32_t n_vocab = populate_vocab_pieces();
|
|
|
|
std::unordered_set<llama_token> allow_settoken;
|
|
for (const auto& piece: slot.allow_pieces) {
|
|
for (const auto token: common_tokenize(model, piece, false, true)) {
|
|
allow_settoken.insert(token);
|
|
}
|
|
}
|
|
|
|
auto n_rules = slot.allow_ruless.size();
|
|
if (n_rules > slot.allow_kws.size() + 1) {
|
|
// allow one more rule set than there are keywords; the last rule set never expires
|
|
n_rules = slot.allow_kws.size() + 1;
|
|
slot.allow_ruless.resize(n_rules);
|
|
} else if (n_rules < slot.allow_kws.size()) {
|
|
// every rule set expires
|
|
slot.allow_kws.resize(n_rules);
|
|
}
|
|
slot.allow_biasess.resize(n_rules);
|
|
|
|
for (size_t i = 0; i < n_rules; ++i) {
|
|
const auto& rules = slot.allow_ruless[i];
|
|
if ((i < slot.allow_ruless_prev.size()) && (rules == slot.allow_ruless_prev[i])) {
|
|
continue;
|
|
}
|
|
LLAMA_LOG_DEBUG("%s: allowlist %zu is new\n", __func__, i);
|
|
|
|
auto& biases = slot.allow_biasess[i];
|
|
biases.resize(n_vocab);
|
|
|
|
std::vector<uint32_t> cpts;
|
|
std::vector<std::string> scripts;
|
|
for (size_t id = 0; id < n_vocab; ++id) {
|
|
const size_t n_cpt = llama_fill_from_utf8(&vocab_pieces[id], &cpts, &scripts);
|
|
float bias = -INFINITY;
|
|
|
|
// each codepoint must be matched by at least one rule, otherwise the token is disallowed
|
|
for (size_t j = 0; j < n_cpt; ++j) {
|
|
bool in_rule = false;
|
|
|
|
// check the rules in priority order (the first matching rule wins)
|
|
for (const auto& rule: rules) {
|
|
const bool in_range = (std::get<0>(rule) <= cpts[j]) && (cpts[j] <= std::get<1>(rule));
|
|
in_rule = in_range && ((std::get<2>(rule) == "*") || std::get<2>(rule) == scripts[j]);
|
|
if (in_rule) {
|
|
// earlier rule has higher priority
|
|
bias = std::max(bias, std::get<3>(rule));
|
|
break;
|
|
}
|
|
}
|
|
if (!in_rule) {
|
|
if ((scripts[j] == "common") || (scripts[j] == "inherited")) {
|
|
// for common or inherited codepoints (e.g. whitespace), defer to other codepoints in the token
|
|
continue;
|
|
}
|
|
|
|
// no rule matched: disallow this token
|
|
bias = -INFINITY;
|
|
break;
|
|
}
|
|
}
|
|
biases[id] = bias;
|
|
}
|
|
|
|
float max_bias = -INFINITY;
|
|
for (const auto& rule: rules) {
|
|
max_bias = std::max(max_bias, std::get<3>(rule));
|
|
}
|
|
for (const auto token: allow_settoken) {
|
|
biases[token] = max_bias;
|
|
}
|
|
}
|
|
} while (false);
|
|
slot.allow_ruless_prev = slot.allow_ruless;
|
|
|
|
if (llama_model_has_recurrent(llama_get_model(slot.ctx))) {
|
|
params_base.can_ban_phrases = false;
|
|
bool do_checkpoint = params_base.ctx_checkpoints_n > 0;
|
|
// make checkpoints only for completion tasks
|
|
do_checkpoint = do_checkpoint && task.type == SERVER_TASK_TYPE_COMPLETION;
|
|
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
|
// checkpoints are created only if:
|
|
// - the model architecture is marked as recurrent or hybrid
|
|
// - context checkpoints are enabled (ctx_checkpoints_n > 0)
// - the task is a completion task
//
|
|
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
|
params_base.do_checkpoint = do_checkpoint;
|
|
if (slot.n_buffer != 0) {
|
|
LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n");
|
|
}
|
|
if (params_base.ctx_shift) {
|
|
params_base.ctx_shift = false;
|
|
LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled");
|
|
}
|
|
}
|
|
{
|
|
const auto& stop = data.find("stop");
|
|
if (stop != data.end() && stop->is_array()) {
|
|
slot.params.antiprompt.clear();
|
|
for (const auto& word : *stop) {
|
|
if (!word.empty()) {
|
|
slot.params.antiprompt.push_back(word);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
const auto samplers = data.find("samplers");
|
|
if (samplers != data.end()) {
|
|
if (samplers->is_array()) {
|
|
slot.sparams.samplers_sequence = llama_sampling_types_from_names(*samplers, false);
|
|
}
|
|
else if (samplers->is_string()) {
|
|
slot.sparams.samplers_sequence = llama_sampling_types_from_chars(samplers->get<std::string>());
|
|
}
|
|
else {
|
|
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
try
|
|
{
|
|
if (slot.ctx_sampling != nullptr) {
|
|
common_sampler_free(slot.ctx_sampling);
|
|
}
|
|
slot.ctx_sampling = common_sampler_init(model, slot.sparams);
|
|
}
|
|
catch (std::exception & e) {
|
|
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
|
|
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
if (slot.ctx_sampling == nullptr) {
|
|
// for now, the only error that may happen here is invalid grammar
|
|
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
|
// slot.prompt_tokens.clear();
|
|
|
|
LOG_INFO("slot is processing task", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
});
|
|
slot.task = std::make_unique<const server_task>(std::move(task));
|
|
return true;
|
|
}
|
|
|
|
void server_context::kv_cache_clear() {
|
|
LOG_VERBOSE("clearing KV cache", {});
|
|
|
|
// clear the entire KV cache
|
|
llama_kv_cache_clear(ctx);
|
|
clean_kv_cache = false;
|
|
}
|
|
|
|
void server_context::system_prompt_update() {
|
|
LOG_VERBOSE("system prompt update", {
|
|
{"system_prompt", system_prompt},
|
|
});
|
|
|
|
kv_cache_clear();
|
|
system_tokens.clear();
|
|
|
|
if (!system_prompt.empty()) {
|
|
system_tokens = ::common_tokenize(ctx, system_prompt, true);
|
|
|
|
const int32_t n_batch = llama_n_batch(ctx);
|
|
const int32_t n_tokens_prompt = system_tokens.size();
|
|
|
|
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
|
|
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
|
|
|
|
common_batch_clear(batch);
|
|
|
|
for (int32_t j = 0; j < n_tokens; ++j) {
|
|
common_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
|
|
}
|
|
|
|
if (llama_decode(ctx, batch) != 0) {
|
|
LOG_ERROR("llama_decode() failed", {});
|
|
return;
|
|
}
|
|
}
|
|
|
|
// assign the system KV cache to all parallel sequences
|
|
for (int32_t i = 1; i <= params_base.n_parallel; ++i) {
|
|
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
|
}
|
|
}
|
|
|
|
system_need_update = false;
|
|
}
|
|
|
|
bool server_context::system_prompt_set(const std::string& sys_prompt) {
|
|
system_prompt = sys_prompt;
|
|
|
|
LOG_VERBOSE("system prompt process", {
|
|
{"system_prompt", system_prompt},
|
|
});
|
|
|
|
// release all slots
|
|
for (server_slot& slot : slots) {
|
|
slot.release();
|
|
}
|
|
|
|
system_need_update = true;
|
|
return true;
|
|
}
|
|
|
|
// keep in sync with process_token(completion_token_output& result, server_slot& slot)
|
|
bool server_context::has_next_token(const completion_token_output& result, server_slot& slot) {
|
|
bool next = true;
|
|
//std::string generate_text = slot.generated_text + result.text_to_send;
|
|
//bool incomplete = validate_utf8(generate_text) < generate_text.size();
|
|
//if (incomplete) {
|
|
// next = true;
|
|
//}
|
|
if (slot.n_decoded > 0 && !slot.has_budget(params_base)) {
|
|
next = false;
|
|
}
|
|
if (llama_token_is_eog(model, result.tok)) {
|
|
next = false;
|
|
}
|
|
auto n_ctx_train = llama_n_ctx_train(model);
|
|
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
|
|
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
|
|
next = false;
|
|
}
|
|
return next;
|
|
}
|
|
|
|
|
|
bool server_context::process_token(completion_token_output& result, server_slot& slot) {
|
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
|
const std::string token_str = result.text_to_send;
|
|
slot.sampled = result.tok;
|
|
|
|
// search stop word and delete it
|
|
slot.last_gentxt_size = slot.generated_text.size();
|
|
slot.generated_text += token_str;
|
|
slot.has_next_token = true;
|
|
|
|
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
|
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
|
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
|
}
|
|
|
|
// check if there is incomplete UTF-8 character at the end
|
|
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
|
|
|
|
if (!incomplete) {
|
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
|
|
const std::string str_test = slot.generated_text.substr(pos);
|
|
bool send_text = true;
|
|
|
|
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
|
|
if (stop_pos != std::string::npos) {
|
|
slot.generated_text.erase(
|
|
slot.generated_text.begin() + pos + stop_pos,
|
|
slot.generated_text.end());
|
|
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
}
|
|
else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) {
|
|
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
|
|
send_text = stop_pos == std::string::npos;
|
|
}
|
|
|
|
// check if there is any token to predict
|
|
if (send_text) {
|
|
// do not send the stop word in the response
|
|
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
|
slot.n_sent_text += result.text_to_send.size();
|
|
// add the token to slot queue and cache
|
|
}
|
|
else {
|
|
result.text_to_send = "";
|
|
}
|
|
|
|
slot.add_token_string(result);
|
|
if (slot.params.stream) {
|
|
send_partial_response(slot, result);
|
|
}
|
|
}
|
|
|
|
if (incomplete) {
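// hold the partial UTF-8 sequence back and keep generating so the remaining bytes can complete it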
|
|
slot.has_next_token = true;
|
|
}
|
|
|
|
// check the limits
|
|
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
|
|
slot.stopped_limit = true;
|
|
slot.has_next_token = false;
|
|
|
|
LOG_VERBOSE("stopped by limit", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_decoded", slot.n_decoded},
|
|
{"n_predict", slot.params.n_predict},
|
|
});
|
|
}
|
|
|
|
if (llama_token_is_eog(model, result.tok)) {
|
|
slot.stopped_eos = true;
|
|
slot.has_next_token = false;
|
|
|
|
LOG_VERBOSE("eos token found", {});
|
|
}
|
|
|
|
auto n_ctx_train = llama_n_ctx_train(model);
|
|
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
|
|
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
|
|
LOG_WARNING("n_predict is not set and self-context extend is disabled."
|
|
" Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
|
|
{ "id_slot", slot.id },
|
|
{ "params.n_predict", slot.params.n_predict },
|
|
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
|
|
{ "slot.n_decoded", slot.n_decoded },
|
|
{ "slot.n_predict", slot.n_predict },
|
|
{ "n_slots", params_base.n_parallel },
|
|
{ "slot.n_ctx", slot.n_ctx },
|
|
{ "n_ctx", n_ctx },
|
|
{ "n_ctx_train", n_ctx_train },
|
|
{ "ga_n", slot.ga_n },
|
|
});
|
|
slot.truncated = true;
|
|
slot.stopped_limit = true;
|
|
slot.has_next_token = false; // stop prediction
|
|
}
|
|
log_text(params_base, "token:"+result.text_to_send);
|
|
LOG_VERBOSE("next token", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"token", result.tok},
|
|
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
|
|
{"has_next_token", slot.has_next_token},
|
|
{"n_remain", slot.n_remaining},
|
|
{"n_decoded", slot.n_decoded},
|
|
{"stopped_eos", slot.stopped_eos},
|
|
{"stopped_word", slot.stopped_word},
|
|
{"stopped_limit", slot.stopped_limit},
|
|
{"stopping_word", slot.stopping_word},
|
|
});
|
|
|
|
return slot.has_next_token; // continue
|
|
}
|
|
|
|
void server_context::populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx) {
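// post_sampling = probabilities taken from the sampler's candidate list (i.e. after the sampling
// chain and penalties); otherwise the top n_probs are computed via get_token_probabilities()
// from the logits at batch position idx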
|
|
size_t n_probs = slot.sparams.n_probs;
|
|
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
|
|
if (post_sampling) {
|
|
const auto* cur_p = common_sampler_get_candidates(slot.ctx_sampling);
|
|
const size_t max_probs = cur_p->size;
|
|
|
|
// set probability for sampled token
|
|
for (size_t i = 0; i < max_probs; i++) {
|
|
if (cur_p->data[i].id == result.tok) {
|
|
result.prob = cur_p->data[i].p;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// set probability for top n_probs tokens
|
|
result.probs.reserve(max_probs);
|
|
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
|
result.probs.push_back({
|
|
cur_p->data[i].id,
|
|
common_token_to_piece(ctx, cur_p->data[i].id, special),
|
|
cur_p->data[i].p
|
|
});
|
|
}
|
|
}
|
|
else {
|
|
auto&& [sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
|
|
|
|
// set probability for sampled token
|
|
result.prob = sampled_token_p;
|
|
|
|
// set probability for top n_probs tokens
|
|
result.probs.reserve(n_probs);
|
|
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
|
result.probs.push_back({
|
|
cur[i].id,
|
|
common_token_to_piece(ctx, cur[i].id, special),
|
|
cur[i].p
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
json server_context::get_formated_generation(const server_slot& slot) const {
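// collects the slot's effective sampling/generation settings as JSON; used for the
// "generation_settings" field of the final response and for the per-slot data in the metrics task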
|
|
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
|
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
|
|
|
std::vector<std::string> samplers_sequence;
|
|
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
|
|
for (const auto& sampler_type : slot.sparams.samplers_sequence) {
|
|
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
|
|
}
|
|
|
|
auto grammar_triggers = json::array();
|
|
for (const auto& trigger : slot.sparams.grammar_triggers) {
|
|
grammar_triggers.push_back(trigger.to_json<json>());
|
|
}
|
|
|
|
return json{
|
|
{"n_ctx", slot.n_ctx},
|
|
{"n_predict", slot.n_predict}, // Server configured n_predict
|
|
{"model", params_base.model_alias},
|
|
{"seed", slot.sparams.seed},
|
|
{"temperature", slot.sparams.temp},
|
|
{"dynatemp_range", slot.sparams.dynatemp_range},
|
|
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
|
{"top_k", slot.sparams.top_k},
|
|
{"top_p", slot.sparams.top_p},
|
|
{"min_p", slot.sparams.min_p},
|
|
{"tfs_z", slot.sparams.tfs_z},
|
|
{"typical_p", slot.sparams.typical_p},
|
|
{"repeat_last_n", slot.sparams.penalty_last_n},
|
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
|
{"presence_penalty", slot.sparams.penalty_present},
|
|
{"frequency_penalty", slot.sparams.penalty_freq},
|
|
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
|
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
|
{"dry_multiplier", slot.sparams.dry_multiplier},
|
|
{"dry_base", slot.sparams.dry_base},
|
|
{"dry_allowed_length", slot.sparams.dry_allowed_length},
|
|
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
|
|
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
|
|
{"mirostat", slot.sparams.mirostat},
|
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|
|
{"adaptive_target", slot.sparams.adaptive_target},
|
|
{"adaptive_decay", slot.sparams.adaptive_decay},
|
|
{"adaptive_updt_w_cur", slot.sparams.adaptive_updt_w_cur},
|
|
{"penalize_nl", slot.sparams.penalize_nl},
|
|
{"stop", slot.params.antiprompt},
|
|
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
|
{"n_keep", slot.params.n_keep},
|
|
{"n_discard", slot.params.n_discard},
|
|
{"ignore_eos", ignore_eos},
|
|
{"stream", slot.params.stream},
|
|
{"logit_bias", slot.sparams.logit_bias},
|
|
{"n_probs", slot.sparams.n_probs},
|
|
{"min_keep", slot.sparams.min_keep},
|
|
{"grammar", slot.sparams.grammar.grammar},
|
|
{"grammar_triggers", grammar_triggers},
|
|
{"preserved_tokens", slot.sparams.preserved_tokens},
|
|
{"chat_format", common_chat_format_name(slot.params.chat_parser_params.format)},
|
|
{"reasoning_format", common_reasoning_format_name(slot.params.chat_parser_params.reasoning_format)},
|
|
{"reasoning_in_content", slot.params.chat_parser_params.reasoning_in_content},
|
|
{"samplers", samplers_sequence}
|
|
};
|
|
}
|
|
|
|
void server_context::send_error(const server_task& task, const std::string& error, const enum error_type type) {
|
|
send_error(task.id, task.id_multi, error, type);
|
|
}
|
|
|
|
void server_context::send_error(const server_slot& slot, const std::string& error, const enum error_type type) {
|
|
send_error(slot.id_task, slot.id_multi, error, type);
|
|
}
|
|
|
|
void server_context::send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type ) {
|
|
LOG_ERROR("task error", {
|
|
{"id_multi", id_multi},
|
|
{"id_task", id_task},
|
|
{"error", error},
|
|
});
|
|
|
|
auto res = std::make_unique<server_task_result_error>();
|
|
res->id = id_task;
|
|
res->id_multi = id_multi;
|
|
res->stop = false;
|
|
res->error = true;
|
|
res->err_type = type;
|
|
res->err_msg = error;
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
// if multimodal is enabled, send an error and return false
|
|
bool server_context::check_no_mtmd(const int id_task) {
|
|
if (mctx) {
|
|
int id_multi = 0;
|
|
send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) {
|
|
if (slot.task == nullptr) {
|
|
return;
|
|
}
|
|
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
|
res->final_result = false;
|
|
res->id = slot.id_task;
|
|
res->id_multi = slot.id_multi;
|
|
res->index = slot.task->index;
|
|
res->error = false;
|
|
res->stop = false;
|
|
res->stream = slot.params.stream;
|
|
res->content = tkn.text_to_send;
|
|
res->post_sampling_probs = slot.params.post_sampling_probs;
|
|
res->oaicompat = slot.params.oaicompat;
|
|
res->oaicompat_model = slot.task->params.oaicompat_model;
|
|
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
|
res->oai_resp_id = slot.oai_resp_id;
|
|
res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id;
|
|
res->oai_resp_message_id = slot.oai_resp_message_id;
|
|
res->oai_resp_fc_id = slot.oai_resp_fc_id;
|
|
res->n_decoded = slot.n_decoded;
|
|
res->n_prompt_tokens = slot.n_prompt_tokens;
|
|
res->data = json{
|
|
{"content", tkn.text_to_send},
|
|
{"stop", false},
|
|
{"id_slot", slot.id},
|
|
{"multimodal", false}
|
|
};
|
|
slot.update_chat_msg(true, res->oaicompat_msg_diffs);
|
|
|
|
res->anthropic_has_reasoning = !slot.chat_msg.reasoning_content.empty();
|
|
|
|
res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started;
|
|
res->anthropic_text_block_started = slot.anthropic_text_block_started;
|
|
|
|
res->oai_resp_thinking_block_started = slot.oai_resp_thinking_block_started;
|
|
res->oai_resp_text_block_started = slot.oai_resp_text_block_started;
|
|
|
|
for (const auto& diff : res->oaicompat_msg_diffs) {
|
|
if (!diff.reasoning_content_delta.empty() && !slot.anthropic_thinking_block_started) {
|
|
slot.anthropic_thinking_block_started = true;
|
|
}
|
|
if (!diff.content_delta.empty() && !slot.anthropic_text_block_started) {
|
|
slot.anthropic_text_block_started = true;
|
|
}
|
|
if (!diff.reasoning_content_delta.empty() && !slot.oai_resp_thinking_block_started) {
|
|
slot.oai_resp_thinking_block_started = true;
|
|
}
|
|
if (!diff.content_delta.empty() && !slot.oai_resp_text_block_started) {
|
|
slot.oai_resp_text_block_started = true;
|
|
}
|
|
if (!diff.tool_call_delta.name.empty()) {
|
|
slot.oai_resp_fc_id = diff.tool_call_delta.id;
|
|
}
|
|
}
|
|
|
|
// populate res->probs_output
|
|
if (slot.sparams.n_probs > 0) {
|
|
res->probs_output = { tkn }; // copy the token probs
|
|
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
|
|
}
|
|
|
|
if (slot.oaicompat) {
|
|
res->data["oaicompat_token_ctr"] = slot.n_decoded;
|
|
res->data["model"] = slot.oaicompat_model;
|
|
}
|
|
|
|
// populate timings if this is final response or timings_per_token is enabled
|
|
if (slot.params.timings_per_token) {
|
|
res->timings = slot.get_timings();
|
|
}
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
void server_context::send_final_response(server_slot& slot) {
|
|
auto res = std::make_unique<server_task_result_cmpl_final>();
|
|
res->final_result = true;
|
|
res->id = slot.id_task;
|
|
res->id_multi = slot.id_multi;
|
|
res->index = slot.task->index;
|
|
res->error = false;
|
|
res->stop = true; // to do: set value
|
|
res->stream = slot.params.stream;
|
|
res->include_usage = slot.params.include_usage;
|
|
res->content = slot.generated_text;
|
|
res->timings = slot.get_timings();
|
|
res->post_sampling_probs = slot.params.post_sampling_probs;
|
|
res->oaicompat = slot.params.oaicompat;
|
|
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
|
res->oaicompat_msg = slot.update_chat_msg(false, res->oaicompat_msg_diffs);
|
|
res->oai_resp_id = slot.oai_resp_id;
|
|
res->oai_resp_reasoning_id = slot.oai_resp_reasoning_id;
|
|
res->oai_resp_message_id = slot.oai_resp_message_id;
|
|
res->n_decoded = slot.n_decoded;
|
|
res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache;
|
|
res->anthropic_thinking_block_started = slot.anthropic_thinking_block_started;
|
|
res->anthropic_text_block_started = slot.anthropic_text_block_started;
|
|
res->n_prompt_tokens = slot.n_prompt_tokens;
|
|
res->oaicompat_model = slot.task->params.oaicompat_model;
|
|
res->data = json{
|
|
{"content", !slot.params.stream ? slot.generated_text : ""},
|
|
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
|
|
{"id_slot", slot.id},
|
|
{"stop", true},
|
|
{"model", params_base.model_alias},
|
|
{"tokens_predicted", slot.n_decoded},
|
|
{"tokens_evaluated", slot.n_prompt_tokens},
|
|
{"generation_settings", get_formated_generation(slot)},
|
|
{"prompt", slot.prompt},
|
|
{"truncated", slot.truncated},
|
|
{"stopped_eos", slot.stopped_eos},
|
|
{"stopped_word", slot.stopped_word},
|
|
{"stopped_limit", slot.stopped_limit},
|
|
{"stopping_word", slot.stopping_word},
|
|
{"tokens_cached", slot.n_past},
|
|
{"timings", slot.get_formated_timings()},
|
|
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
|
|
};
|
|
|
|
// populate res->probs_output
|
|
if (slot.sparams.n_probs > 0) {
|
|
res->probs_output = std::vector<completion_token_output>(
|
|
slot.generated_token_probs.begin(),
|
|
slot.generated_token_probs.end());
|
|
res->data["completion_probabilities"] = probs_vector_to_json(ctx, res->probs_output);
|
|
}
|
|
|
|
if (slot.oaicompat) {
|
|
res->data["oaicompat_token_ctr"] = slot.n_decoded;
|
|
res->data["model"] = slot.oaicompat_model;
|
|
}
|
|
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) {
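// with LLAMA_POOLING_TYPE_NONE a raw, unnormalized embedding is pushed for every output token;
// with pooling enabled a single normalized sequence embedding is returned and the loop stops early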
|
|
auto res = std::make_unique<server_task_result_embd>();
|
|
res->id = slot.task->id;
|
|
res->index = slot.task->index;
|
|
res->server_task_result::index = slot.task->index;
|
|
res->n_tokens = slot.prompt_tokens.size();
|
|
res->oaicompat = slot.task->params.oaicompat;
|
|
|
|
const int n_embd = llama_model_n_embd(model);
|
|
|
|
std::vector<float> embd_res(n_embd, 0.0f);
|
|
|
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
|
continue;
|
|
}
|
|
|
|
const float* embd = nullptr;
|
|
if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
|
embd = llama_get_embeddings_ith(ctx, i);
|
|
}
|
|
else {
|
|
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
|
}
|
|
|
|
if (embd == nullptr) {
|
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
|
|
|
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
|
continue;
|
|
}
|
|
|
|
// normalize only when there is pooling
|
|
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
|
|
common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
|
|
res->embedding.push_back(embd_res);
|
|
break;
|
|
}
|
|
|
|
res->embedding.emplace_back(embd, embd + n_embd);
|
|
}
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
void server_context::apply_server_biases(server_slot& slot) {
|
|
auto& server_biases = slot.ctx_sampling->server_biases;
|
|
|
|
if (slot.allow_idx < slot.allow_biasess.size()) {
|
|
server_biases = &slot.allow_biasess[slot.allow_idx];
|
|
} else {
|
|
server_biases = nullptr;
|
|
}
|
|
}
|
|
|
|
void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) {
|
|
server_task task;
|
|
task.id = id_task;
|
|
task.id_multi = id_multi;
|
|
task.id_target = 0;
|
|
task.data = std::move(data);
|
|
task.infill = infill;
|
|
task.embedding = embedding;
|
|
task.type = SERVER_TASK_TYPE_COMPLETION;
|
|
task.tokens = std::move(inputs);
|
|
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
|
// otherwise, it's a single-prompt task, we actually queue it
|
|
// if there are numbers in the prompt array, it will be treated as an array of tokens
|
|
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
|
|
bool numbers = false;
|
|
for (const auto& e : task.data.at("prompt")) {
|
|
if (e.is_number()) {
|
|
numbers = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
|
|
// it will completely stall the server; the root cause of that bug has not been tracked down.
|
|
//
|
|
// if there are numbers, it needs to be treated like a single prompt,
|
|
// queue_tasks handles a mix of strings and numbers just fine.
|
|
if (numbers) {
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
else {
|
|
split_multiprompt_task(id_task, task);
|
|
}
|
|
}
|
|
else {
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
}
|
|
|
|
void server_context::request_cancel(int id_task) {
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_CANCEL;
|
|
task.id_target = id_task;
|
|
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
|
|
void server_context::split_multiprompt_task(int id_multi, server_task& multiprompt_task) {
|
|
const int prompt_count = multiprompt_task.data.at("prompt").size();
|
|
if (prompt_count <= 1) {
|
|
send_error(multiprompt_task, "error while handling multiple prompts");
|
|
return;
|
|
}
|
|
|
|
// generate all the ID for subtask
|
|
std::vector<int> subtask_ids(prompt_count);
|
|
for (int i = 0; i < prompt_count; i++) {
|
|
subtask_ids[i] = queue_tasks.get_new_id();
|
|
}
|
|
|
|
// queue up the multitask so we can track its subtask progression
|
|
queue_tasks.add_multitask(id_multi, subtask_ids);
|
|
|
|
// add subtasks
|
|
for (int i = 0; i < prompt_count; i++) {
|
|
json subtask_data = multiprompt_task.data;
|
|
subtask_data["prompt"] = subtask_data.at("prompt")[i];
|
|
|
|
// subtasks inherit everything else (infill mode, embedding mode, etc.)
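// note: multiprompt_task.tokens is moved from on the first iteration, so later subtasks receive
// a moved-from (typically empty) token set and effectively rely on the per-subtask "prompt" field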
|
|
request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding,
|
|
std::move(multiprompt_task.tokens));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
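// binary layout: magic, version, checkpoint count, then for each checkpoint its four position
// fields followed by the length and raw bytes of its state data; load_checkpoints_from_file()
// below reads back the same layout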
|
|
if (checkpoints.size() == 0) {
|
|
return 0;
|
|
}
|
|
std::ofstream file(filename, std::ios::binary);
if (!file.is_open()) {
return 0;
}
|
|
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
|
|
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
|
|
uint32_t version = LLAMA_STATE_SEQ_VERSION;
|
|
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
|
|
size_t count = checkpoints.size();
|
|
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
|
|
|
|
for (const auto & checkpoint : checkpoints) {
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
|
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
|
size_t data_len = checkpoint.data.size();
|
|
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
|
|
if (data_len > 0) {
|
|
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
|
}
|
|
}
|
|
size_t pos = file.tellp();
|
|
file.close();
|
|
return pos;
|
|
}
|
|
|
|
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
|
|
std::ifstream file(filename, std::ios::binary);
|
|
if (!file.is_open()) {
|
|
return 0;
|
|
}
|
|
checkpoints.clear();
|
|
// version checks
|
|
{
|
|
uint32_t magic;
|
|
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
|
|
uint32_t version;
|
|
file.read(reinterpret_cast<char *>(&version), sizeof(version));
|
|
|
|
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
|
LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
|
|
return 0;
|
|
}
|
|
}
|
|
// load the checkpoints
|
|
{
|
|
size_t count;
|
|
file.read(reinterpret_cast<char *>(&count), sizeof(count));
|
|
|
|
for (size_t i = 0; i < count; i++) {
|
|
server_prompt_checkpoint checkpoint;
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
|
file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
|
|
|
size_t data_len;
|
|
file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
|
|
if (data_len > 0) {
|
|
checkpoint.data.resize(data_len);
|
|
file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
|
}
|
|
checkpoints.push_back(checkpoint);
|
|
}
|
|
}
|
|
size_t pos = file.tellg();
|
|
file.close();
|
|
return pos;
|
|
}
|
|
|
|
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
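// the token list is serialized as JSON with additional "magic" and "version" fields that are
// validated on load; returns the number of bytes written (0 if the file could not be opened)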
|
|
std::ofstream file(filename, std::ios::binary);
|
|
json token_json = tokens.to_json();
|
|
token_json["magic"] = LLAMA_SERVER_MAGIC;
|
|
token_json["version"] = LLAMA_SERVER_VERSION;
|
|
size_t pos = 0;
|
|
if (file.is_open()) {
|
|
file << token_json;
|
|
pos = file.tellp();
|
|
file.close();
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
|
|
std::ifstream file(filename, std::ios::binary);
|
|
if (!file.is_open()) {
|
|
return 0;
|
|
}
|
|
size_t pos = 0;
|
|
json token_json;
|
|
if (file.is_open()) {
|
|
file >> token_json;
|
|
pos = file.tellg();
|
|
file.close();
|
|
}
|
|
uint32_t magic = token_json.value<uint32_t>("magic", 0);
|
|
uint32_t version = token_json.value<uint32_t>("version", 0);
|
|
if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
|
|
LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
|
|
return 0;
|
|
}
|
|
tokens.from_json(token_json);
|
|
|
|
return pos;
|
|
}
|
|
|
|
void server_context::process_single_task(server_task&& task) {
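// tasks are handled by type: completion/infill/embedding/rerank tasks acquire a slot (or are
// deferred when no suitable slot is free), while the remaining types are control/maintenance ops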
|
|
switch (task.type) {
|
|
case SERVER_TASK_TYPE_COMPLETION:
|
|
case SERVER_TASK_TYPE_INFILL:
|
|
case SERVER_TASK_TYPE_EMBEDDING:
|
|
case SERVER_TASK_TYPE_RERANK:
|
|
{
|
|
const int id_slot = json_value(task.data, "id_slot", -1);
|
|
|
|
server_slot* slot;
|
|
|
|
if (id_slot != -1) {
|
|
slot = get_slot_by_id(id_slot);
|
|
}
|
|
else {
|
|
slot = get_available_slot(task);
|
|
}
|
|
|
|
if (slot == nullptr) {
|
|
// if no slot is available, we defer this task for processing later
|
|
LOG_VERBOSE("no slot is available", { {"id_task", task.id} });
|
|
queue_tasks.defer(std::move(task));
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
|
|
queue_tasks.defer(std::move(task));
|
|
break;
|
|
}
|
|
|
|
if (task.data.contains("system_prompt")) {
|
|
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
|
|
system_prompt_set(sys_prompt);
|
|
|
|
for (server_slot& slot : slots) {
|
|
slot.n_past = 0;
|
|
slot.n_past_se = 0;
|
|
}
|
|
}
|
|
|
|
slot->reset();
|
|
|
|
slot->id_task = task.id;
|
|
slot->id_multi = task.id_multi;
|
|
slot->infill = task.infill;
|
|
slot->embedding = task.embedding;
|
|
|
|
if (!launch_slot_with_task(*slot, task)) {
|
|
LOG_ERROR("error while launching slot", task.data);
|
|
break;
|
|
}
|
|
} break;
|
|
case SERVER_TASK_TYPE_CANCEL:
|
|
{
|
|
// release slot linked with the task id
|
|
for (auto& slot : slots) {
|
|
if (slot.id_task == task.id_target) {
|
|
slot.release();
|
|
break;
|
|
}
|
|
}
|
|
} break;
|
|
case SERVER_TASK_TYPE_NEXT_RESPONSE:
|
|
{
|
|
// do nothing
|
|
} break;
|
|
case SERVER_TASK_TYPE_METRICS:
|
|
{
|
|
json slots_data = json::array();
|
|
|
|
int n_idle_slots = 0;
|
|
int n_processing_slots = 0;
|
|
|
|
for (server_slot& slot : slots) {
|
|
json slot_data = get_formated_generation(slot);
|
|
slot_data["id"] = slot.id;
|
|
slot_data["id_task"] = slot.id_task;
|
|
slot_data["state"] = slot.state;
|
|
slot_data["prompt"] = slot.prompt;
|
|
slot_data["next_token"] = {
|
|
{"has_next_token", slot.has_next_token},
|
|
{"n_remain", slot.n_remaining},
|
|
{"n_decoded", slot.n_decoded},
|
|
{"stopped_eos", slot.stopped_eos},
|
|
{"stopped_word", slot.stopped_word},
|
|
{"stopped_limit", slot.stopped_limit},
|
|
{"stopping_word", slot.stopping_word},
|
|
};
|
|
|
|
if (slot_data["state"] == SLOT_STATE_IDLE) {
|
|
n_idle_slots++;
|
|
}
|
|
else {
|
|
n_processing_slots++;
|
|
}
|
|
|
|
slots_data.push_back(slot_data);
|
|
}
|
|
LOG_INFO("slot data", {
|
|
{"id_task", task.id},
|
|
{"n_idle_slots", n_idle_slots},
|
|
{"n_processing_slots", n_processing_slots}
|
|
});
|
|
|
|
LOG_VERBOSE("slot data", {
|
|
{"id_task", task.id},
|
|
{"n_idle_slots", n_idle_slots},
|
|
{"n_processing_slots", n_processing_slots},
|
|
{"slots", slots_data}
|
|
});
|
|
|
|
server_task_result res;
|
|
res.id = task.id;
|
|
res.id_multi = task.id_multi;
|
|
res.stop = true;
|
|
res.error = false;
|
|
res.data = {
|
|
{ "idle", n_idle_slots },
|
|
{ "processing", n_processing_slots },
|
|
{ "deferred", queue_tasks.queue_tasks_deferred.size() },
|
|
{ "t_start", metrics.t_start},
|
|
|
|
{ "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
|
|
{ "t_tokens_generation_total", metrics.t_tokens_generation_total},
|
|
{ "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
|
|
{ "t_prompt_processing_total", metrics.t_prompt_processing_total},
|
|
|
|
{ "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
|
|
{ "t_prompt_processing", metrics.t_prompt_processing},
|
|
{ "n_tokens_predicted", metrics.n_tokens_predicted},
|
|
{ "t_tokens_generation", metrics.t_tokens_generation},
|
|
|
|
{ "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
|
|
{ "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
|
|
|
|
{ "slots", slots_data },
|
|
};
|
|
|
|
if (json_value(task.data, "reset_bucket", false)) {
|
|
metrics.reset_bucket();
|
|
}
|
|
queue_results.send(res);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SLOT_SAVE:
|
|
{
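// a slot save produces three files: the llama sequence state at <filepath>, the server tokens
// at <filepath>.tokens.json and the prompt checkpoints at <filepath>.checkpoints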
|
|
int id_slot = task.data.at("id_slot");
|
|
server_slot* slot = get_slot_by_id(id_slot);
|
|
if (slot == nullptr) {
|
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
|
|
queue_tasks.defer(std::move(task));
|
|
break;
|
|
}
|
|
|
|
const size_t token_count = slot->cache_tokens.size();
|
|
const int64_t t_start = ggml_time_us();
|
|
|
|
std::string filename = task.data.at("filename");
|
|
std::string filepath = task.data.at("filepath");
|
|
save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
|
|
size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
|
|
|
|
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
|
|
|
|
const int64_t t_end = ggml_time_us();
|
|
const double t_save_ms = (t_end - t_start) / 1000.0;
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json{
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "n_saved", token_count }, // tokens saved
|
|
{ "n_written", nwrite + saved }, // bytes written
|
|
{ "timings", {
|
|
{ "save_ms", t_save_ms }
|
|
} }
|
|
};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
|
{
|
|
int id_slot = task.data.at("id_slot");
|
|
server_slot* slot = get_slot_by_id(id_slot);
|
|
if (slot == nullptr) {
|
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
|
|
queue_tasks.defer(std::move(task));
|
|
break;
|
|
}
|
|
const int64_t t_start = ggml_time_us();
|
|
|
|
std::string filename = task.data.at("filename");
|
|
std::string filepath = task.data.at("filepath");
|
|
|
|
slot->cache_tokens.resize(slot->n_ctx);
|
|
size_t token_count = 0;
|
|
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
|
if (nread == 0) {
|
|
slot->cache_tokens.resize(0);
|
|
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
|
|
size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
|
|
|
|
const int64_t t_end = ggml_time_us();
|
|
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json{
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "n_restored", token_count }, // tokens restored
|
|
{ "n_read", nread }, // bytes read
|
|
{ "timings", {
|
|
{ "restore_ms", t_restore_ms }
|
|
} }
|
|
};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SLOT_ERASE:
|
|
{
|
|
int id_slot = task.data.at("id_slot");
|
|
server_slot* slot = get_slot_by_id(id_slot);
|
|
if (slot == nullptr) {
|
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} });
|
|
queue_tasks.defer(std::move(task));
|
|
break;
|
|
}
|
|
// Erase token cache
|
|
const size_t n_erased = slot->cache_tokens.size();
|
|
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
|
slot->cache_tokens.keep_first(0);
|
|
//slot->cache_tokens.clear();
|
|
slot->server_cached_prompt.checkpoints.clear();
|
|
slot->server_cached_prompt.data.clear();
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json{
|
|
{ "id_slot", id_slot },
|
|
{ "n_erased", n_erased }
|
|
};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SET_LORA:
|
|
{
|
|
llama_lora_adapters_apply(ctx, lora_adapters);
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json{ { "success", true } };
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_LOAD_CONTROL_VECTOR:
|
|
{
|
|
// Load control vector from file
|
|
std::string path = task.data.at("path");
|
|
float scale = task.data.value("scale", 1.0f);
|
|
int32_t layer_start = task.data.value("layer_start", 1);
|
|
int32_t layer_end = task.data.value("layer_end", llama_n_layer(model));
|
|
|
|
// Check if already loaded
|
|
int cv_id = -1;
|
|
for (size_t i = 0; i < control_vectors.size(); i++) {
|
|
if (control_vectors[i].path == path) {
|
|
control_vectors[i].scale = scale;
|
|
control_vectors[i].layer_start = layer_start;
|
|
control_vectors[i].layer_end = layer_end;
|
|
cv_id = i;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (cv_id == -1) {
|
|
control_vector_container new_cv;
|
|
new_cv.path = path;
|
|
new_cv.scale = scale;
|
|
new_cv.layer_start = layer_start;
|
|
new_cv.layer_end = layer_end;
|
|
new_cv.applied = false;
|
|
|
|
// Load the control vector data
|
|
llama_control_vector_load_info load_info;
|
|
load_info.fname = path;
|
|
load_info.strength = 1.0f; // Don't pre-scale here, we'll scale when applying
|
|
|
|
std::vector<llama_control_vector_load_info> load_infos = { load_info };
|
|
new_cv.data = llama_control_vector_load(load_infos);
|
|
|
|
if (new_cv.data.n_embd == -1) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false }, { "error", "Failed to load control vector from " + path }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
// Validate dimension to prevent heap corruption
|
|
if (new_cv.data.n_embd != llama_model_n_embd(model)) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false },
|
|
{ "error", "Vector dimension mismatch" }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
control_vectors.push_back(new_cv);
|
|
|
|
cv_id = control_vectors.size() - 1;
|
|
}
|
|
|
|
// Auto-apply control vectors after loading
|
|
if (!apply_control_vectors_internal()) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = false;
|
|
result.data = json{{ "success", true }, { "id", cv_id }};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_UNLOAD_CONTROL_VECTOR:
|
|
{
|
|
// Validate that "id" field exists and is a number
|
|
if (!task.data.contains("id") || task.data["id"].is_null() || !task.data["id"].is_number()) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false }, { "error", "Missing or invalid 'id' field" }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
int id = task.data.at("id");
|
|
|
|
if (id < 0 || id >= (int)control_vectors.size()) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false }, { "error", "Invalid control vector ID" }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
// Remove the control vector from the list
|
|
control_vectors.erase(control_vectors.begin() + id);
|
|
|
|
// Reapply remaining control vectors
|
|
if (!apply_control_vectors_internal()) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = false;
|
|
result.data = json{{ "success", true }};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SET_CONTROL_VECTOR:
|
|
{
|
|
if (!apply_control_vectors_internal()) {
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = true;
|
|
result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }};
|
|
queue_results.send(result);
|
|
break;
|
|
}
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.error = false;
|
|
result.data = json{{ "success", true }};
|
|
queue_results.send(result);
|
|
} break;
|
|
}
|
|
}
|
|
|
|
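// combine all active control vectors (each scaled by its own strength) into a single vector and apply it over the union of their layer ranges; clears any applied vector when none are active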
bool server_context::apply_control_vectors_internal() {
|
|
llama_control_vector_data combined_cv = { -1, {} };
|
|
|
|
// Check if we have anything to apply
|
|
bool any_active = false;
|
|
for (const auto& cv : control_vectors) {
|
|
if (cv.scale != 0.0f) {
|
|
any_active = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!any_active) {
|
|
// Clear control vectors if nothing is active
|
|
llama_control_vector_apply(ctx, nullptr, 0, 0, 0, 0);
|
|
return true;
|
|
}
|
|
|
|
// Aggregate control vectors with scaling
|
|
for (auto& cv : control_vectors) {
|
|
if (cv.scale == 0.0f) {
|
|
cv.applied = false;
|
|
continue;
|
|
}
|
|
|
|
if (combined_cv.n_embd == -1) {
|
|
combined_cv.n_embd = cv.data.n_embd;
|
|
combined_cv.data.resize(cv.data.data.size(), 0.0f);
|
|
}
|
|
|
|
for (size_t i = 0; i < cv.data.data.size(); i++) {
|
|
combined_cv.data[i] += cv.data.data[i] * cv.scale;
|
|
}
|
|
cv.applied = true;
|
|
}
|
|
|
|
// Apply combined control vector
|
|
if (combined_cv.n_embd != -1 && !combined_cv.data.empty()) {
|
|
int32_t min_layer_start = INT32_MAX;
|
|
int32_t max_layer_end = 0;
|
|
|
|
for (const auto& cv : control_vectors) {
|
|
if (cv.scale != 0.0f) {
|
|
min_layer_start = std::min(min_layer_start, cv.layer_start);
|
|
max_layer_end = std::max(max_layer_end, cv.layer_end);
|
|
}
|
|
}
|
|
|
|
int err = llama_control_vector_apply(ctx,
|
|
combined_cv.data.data(),
|
|
combined_cv.data.size(),
|
|
combined_cv.n_embd,
|
|
min_layer_start,
|
|
max_layer_end);
|
|
return (err == 0);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void server_context::on_finish_multitask(const server_task_multi& multitask) {
|
|
// all subtasks done == multitask is done
|
|
server_task_result result;
|
|
result.id = multitask.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
|
|
// collect json results into one json result
|
|
std::vector<json> result_jsons;
|
|
for (const auto& subres : multitask.results) {
|
|
result_jsons.push_back(subres.data);
|
|
result.error = result.error && subres.error;
|
|
}
|
|
result.data = json{
|
|
{ "results", result_jsons }
|
|
};
|
|
|
|
queue_results.send(result);
|
|
}
|
|
|
|
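// debug helper: log detokenized windows of the cache and prompt token sequences for side-by-side comparison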
void server_context::print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1, size_t start2, size_t length) {
|
|
if (cache.size() > start2) {
|
|
LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str());
|
|
}
|
|
if (prompt.size() > start1) {
|
|
LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str());
|
|
}
|
|
|
|
}
|
|
|
|
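// remove n_discard tokens (after the first n_keep) from the KV cache and shift the remaining entries back so positions stay contiguous; also updates the speculative context and the slot's token cache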
void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) {
|
|
auto kv_keep = slot.cache_tokens.pos_next(n_keep);
|
|
auto kv_discard = slot.cache_tokens.pos_next(n_keep + n_discard) - kv_keep;
|
|
auto kv_past = slot.cache_tokens.pos_next(slot.n_past);
|
|
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
|
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
|
llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard);
|
|
llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard);
|
|
if (slot.has_mtp && slot.spec) {
|
|
common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past);
|
|
}
|
|
if (slot.params.cache_prompt) {
|
|
slot.cache_tokens.discard_n_tokens(n_keep, n_discard);
|
|
}
|
|
}
|
|
|
|
|
|
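// a context shift is only safe when the shift boundaries fall on real text tokens; for multimodal prompts, LLAMA_TOKEN_NULL marks image chunks that must not be split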
inline static bool tokens_support_context_shift(const server_tokens & tokens, int32_t n_keep,
|
|
int32_t n_discard) {
|
|
bool can_shift = !tokens.has_mtmd;
|
|
if (tokens.has_mtmd) {
|
|
can_shift = true;
|
|
if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
|
|
can_shift = tokens[n_keep - 1] != LLAMA_TOKEN_NULL;
|
|
}
|
|
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
|
|
can_shift = can_shift && tokens[n_discard + n_keep - 1] != LLAMA_TOKEN_NULL;
|
|
}
|
|
}
|
|
return can_shift;
|
|
}
|
|
|
|
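// nudge n_keep down and n_discard up until the shift boundaries no longer fall inside an image chunk (multimodal prompts only)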
inline static void adjust_n_to_support_context_shift(const server_tokens & tokens, int32_t & n_keep,
|
|
int32_t & n_discard) {
|
|
if (!tokens.has_mtmd) {
|
|
return;
|
|
}
|
|
if (n_keep > 0 && n_keep <= tokens.n_tokens()) {
|
|
while (tokens[n_keep - 1] == LLAMA_TOKEN_NULL) {
|
|
n_keep--;
|
|
if (n_keep < 1 || n_keep > tokens.size()) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) {
|
|
while (tokens[n_discard + n_keep - 1] == LLAMA_TOKEN_NULL) {
|
|
n_discard++;
|
|
if (n_discard + n_keep < 1 || n_discard + n_keep > tokens.size()) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
|
|
// map "keep the first n_keep, discard the next n_discard" counts defined on sequence a onto the corresponding counts for sequence b, using their common prefixes to account for tokenization differences
|
|
void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
|
|
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) {
|
|
|
|
common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact);
|
|
common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact);
|
|
// only if there is enough common token
|
|
int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep);
|
|
int32_t keep_offset = ctx_keep_prefix.first - n_keep;
|
|
n_kept = ctx_keep_prefix.second - keep_offset;
|
|
n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset;
|
|
if (n_kept < 0) {
|
|
n_kept = n_keep;
|
|
}
|
|
if (n_discarded < 0) {
|
|
n_discarded = n_discard;
|
|
}
|
|
}
|
|
|
|
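// prompt-time context shift: drop blocks of n_discard tokens after the first n_keep until the prompt fits into the slot context, removing the matching span from the token/KV caches so it is not reprocessed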
void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) {
|
|
int n_keep = std::max(0, slot.params.n_keep + add_bos_token);
|
|
const int n_left = slot.n_ctx - n_keep;
|
|
int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
|
adjust_n_to_support_context_shift(slot.prompt_tokens, n_keep, n_discard);
|
|
if (n_discard <= 0 || !tokens_support_context_shift(slot.prompt_tokens, n_keep, n_discard)) {
|
|
return;
|
|
}
|
|
int n_discard_prompt = 0;
|
|
// we still need to truncate input since we have not discarded enough tokens
|
|
while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) {
|
|
slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
|
|
n_discard_prompt = n_discard_prompt + n_discard;
|
|
}
|
|
|
|
// Handle mistokenization between prompt and cache during context shift
|
|
//
|
|
int32_t n_discard_cache = n_discard_prompt;
|
|
int32_t n_kept = n_keep;
|
|
slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt);
|
|
if (n_discard_prompt > 0) {
|
|
context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
|
|
n_discard, n_kept, n_discard_cache, exact);
|
|
}
|
|
|
|
int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0);
|
|
n_discard_cache = std::min(n_discard_cache, n_discard_cache_max);
|
|
// discard matching tokens from cache and kv cache to avoid reprocessing the prompt
|
|
if (n_discard_cache > 0) {
|
|
discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
|
|
}
|
|
// discard extra tokens from prompts
|
|
slot.n_kept_prompt = n_keep;
|
|
slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt);
|
|
slot.n_prompt_tokens = slot.prompt_tokens.size();
|
|
}
|
|
|
|
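// move slots flagged with SLOT_COMMAND_RELEASE back to the idle state and notify the task queue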
void server_context::release_slots()
|
|
{
|
|
for (auto& slot : slots) {
|
|
if (slot.command == SLOT_COMMAND_RELEASE) {
|
|
slot.state = SLOT_STATE_IDLE;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.t_last_used = ggml_time_us();
|
|
|
|
LOG_INFO("slot released", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", n_ctx},
|
|
{"n_past", slot.n_past},
|
|
{"n_system_tokens", system_tokens.size()},
|
|
{"n_cache_tokens", slot.cache_tokens.size()},
|
|
{"truncated", slot.truncated}
|
|
});
|
|
|
|
queue_tasks.notify_slot_changed();
|
|
}
|
|
}
|
|
}
|
|
|
|
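// returns true when every slot is idle; in that case the KV cache is cleared if there is no system prompt to preserve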
bool server_context::slots_idle() {
|
|
bool all_idle = true;
|
|
for (auto& slot : slots) {
|
|
if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
|
|
all_idle = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (all_idle) {
|
|
LOG_INFO("all slots are idle", {});
|
|
if (system_prompt.empty() && clean_kv_cache) {
|
|
kv_cache_clear();
|
|
}
|
|
all_idle = true;
|
|
}
|
|
return all_idle;
|
|
}
|
|
|
|
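// generation-time context shift: when a slot gets close to its context limit, keep the first n_keep tokens and discard the next n_discard from the caches so generation can continue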
void server_context::context_shift() {
|
|
for (server_slot& slot : slots) {
|
|
if (slot.ga_n == 1) {
|
|
if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
|
|
if (!params_base.ctx_shift) {
|
|
// this check is redundant (for good)
|
|
// we should never get here, because generation should have already stopped in process_token()
|
|
slot.print_timings();
|
|
slot.release();
|
|
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
|
|
continue;
|
|
}
|
|
// Shift context
|
|
int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
|
|
if (add_bos_token) {
|
|
n_keep += 1;
|
|
}
|
|
n_keep = std::min(slot.n_ctx - 4, n_keep);
|
|
|
|
const int32_t n_left = (int)system_tokens.size() + slot.n_past - n_keep;
|
|
int32_t n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
|
int32_t n_kept;
|
|
int32_t n_discard_cache;
|
|
adjust_n_to_support_context_shift(slot.cache_tokens, n_keep, n_discard);
|
|
if (n_discard > 0 && tokens_support_context_shift(slot.cache_tokens, n_keep, n_discard)) {
|
|
context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
|
|
n_discard, n_kept, n_discard_cache);
|
|
LOG_INFO("slot context shift", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_keep", n_keep},
|
|
{"n_left", n_left},
|
|
{"n_discard", n_discard},
|
|
{"n_ctx", n_ctx},
|
|
{"n_past", slot.n_past},
|
|
{"n_system_tokens", system_tokens.size()},
|
|
{"n_cache_tokens", slot.cache_tokens.size()}
|
|
});
|
|
slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
|
|
slot.n_kept_prompt = n_keep;
|
|
discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
|
|
slot.n_past -= n_discard_cache;
|
|
slot.truncated = true;
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
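// queue the last sampled token of each active slot (plus any speculative draft tokens) into the next decode batch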
void server_context::add_sampled_tokens() {
|
|
for (auto& slot : slots) {
|
|
slot.released = false;
|
|
if (slot.state == SLOT_STATE_IDLE) {
|
|
continue;
|
|
}
|
|
|
|
// generate draft tokens in speculative decoding mode
|
|
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
|
// perform the speculative drafting for all sequences at the same time in a single batch
|
|
const int n_draft_max_pre = slot.get_n_draft_max();
|
|
if (n_draft_max_pre > 0) {
|
|
if (mctx) {
|
|
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
|
GGML_ABORT("not supported by multimodal");
|
|
}
|
|
|
|
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
|
|
|
|
auto & params_spec = slot.params.speculative;
|
|
|
|
if (slot.has_mtp) {
|
|
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
|
llama_context * hs_ctx = mtp_ctx ? mtp_ctx : ctx;
|
|
if (!slot.mtp_hidden_state.empty()) {
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
const int n_hidden = slot.mtp_hidden_state.size() / n_embd;
|
|
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data() + (n_hidden - 1) * n_embd);
|
|
} else {
|
|
LOG_ERROR("MTP hidden state is empty during speculation", {});
|
|
const float* emb_neg1 = llama_get_embeddings_ith(ctx, -1);
|
|
if (emb_neg1) {
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
slot.mtp_hidden_state.resize(n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb_neg1, n_embd * sizeof(float));
|
|
llama_set_draft_input_hidden_state(hs_ctx, slot.mtp_hidden_state.data());
|
|
}
|
|
}
|
|
}
|
|
|
|
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
|
|
|
const int n_draft_max = slot.get_n_draft_max();
|
|
|
|
if (draft.size() > (size_t)n_draft_max) {
|
|
if (slot.params.speculative.autotune) {
|
|
// expected near end-of-response when autotune shrinks n_max
|
|
SLT_DBG(slot, "draft size %d exceeds max %d, truncating\n", (int)draft.size(), n_draft_max);
|
|
} else {
|
|
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int)draft.size(), n_draft_max);
|
|
}
|
|
draft.resize(n_draft_max);
|
|
}
|
|
|
|
// add the sampled token to the batch
|
|
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
|
|
if (slot.params.speculative.n_min > (int)draft.size()) {
|
|
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
|
|
// fallback to normal decoding
|
|
slot.i_batch = slot.i_batch_dft[0];
|
|
slot.drafted.clear();
|
|
slot.i_batch_dft.clear();
|
|
}
|
|
else {
|
|
// keep track of total number of drafted tokens tested
|
|
slot.n_draft_total += draft.size();
|
|
|
|
// add all drafted tokens to the batch
|
|
for (size_t i = 0; i < draft.size(); i++) {
|
|
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
common_batch_add(batch, draft[i], slot.cache_tokens.pos_next(), { slot.id }, true);
|
|
slot.cache_tokens.push_back(draft[i]);
|
|
}
|
|
slot.drafted = std::move(draft);
|
|
}
|
|
}
|
|
else {
|
|
// no speculative decoding
|
|
slot.i_batch = batch.n_tokens;
|
|
|
|
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
|
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
|
|
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
|
(int)slot.n_ctx, (int)slot.cache_tokens.size(), (int)slot.truncated);
|
|
}
|
|
slot.n_past = slot.cache_tokens.n_tokens();
|
|
}
|
|
}
|
|
|
|
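// create a new context checkpoint once the sequence has advanced ctx_checkpoints_interval positions past the previous one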
void server_context::create_checkpoint_at_interval(server_slot & slot, const gpt_params & params_base) {
|
|
if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) {
|
|
auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
|
if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) {
|
|
bool created = create_checkpoint(slot);
|
|
if (created) {
|
|
slot.checkpoint_pos = pos;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
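// if the KV cache no longer covers the positions we want to reuse, restore the most recent usable checkpoint (or fall back to full re-processing), and prune checkpoints that have become invalid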
void server_context::apply_checkpoint(server_slot & slot) {
|
|
llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past);
|
|
const auto pos_min_thold = std::max(0, pos_next - 1);
|
|
if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) {
|
|
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
|
|
|
if (pos_min > pos_min_thold) {
|
|
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min);
|
|
|
|
// search for a context checkpoint
|
|
const auto it = std::find_if(
|
|
slot.server_cached_prompt.checkpoints.rbegin(),
|
|
slot.server_cached_prompt.checkpoints.rend(),
|
|
[&](const auto & cur) {
|
|
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
|
return cur.pos_min < pos_min_thold;
|
|
}
|
|
);
|
|
|
|
bool do_reset = it == slot.server_cached_prompt.checkpoints.rend();
|
|
|
|
if (!do_reset) {
|
|
// restore the context checkpoint
|
|
const int64_t t_start = ggml_time_us();
|
|
const size_t checkpoint_size = it->data.size();
|
|
const size_t n = llama_state_seq_set_data(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
if (n != checkpoint_size) {
|
|
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
|
do_reset = true;
|
|
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
|
} else {
|
|
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
|
|
slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1);
|
|
slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt));
|
|
slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1);
|
|
SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024);
|
|
}
|
|
}
|
|
|
|
if (do_reset) {
|
|
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
|
|
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
|
slot.n_past = 0;
|
|
slot.n_past_prompt = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
// erase any checkpoints with pos_min > pos_min_thold
|
|
for (auto it = slot.server_cached_prompt.checkpoints.begin(); it != slot.server_cached_prompt.checkpoints.end();) {
|
|
const auto & cur = *it;
|
|
if (cur.pos_min > pos_min_thold) {
|
|
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
|
it = slot.server_cached_prompt.checkpoints.erase(it);
|
|
} else {
|
|
++it;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
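// snapshot the slot's sequence state so it can later be restored by apply_checkpoint(); skips tiny or redundant checkpoints and evicts the oldest one when the limit is reached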
bool server_context::create_checkpoint(server_slot & slot) {
|
|
bool do_checkpoint = !slot.image_just_processed;
|
|
int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
|
|
const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
|
|
|
|
// no need for empty or small checkpoints
|
|
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 16);
|
|
|
|
// no need to create checkpoints that are too close together
|
|
do_checkpoint = do_checkpoint && (slot.server_cached_prompt.checkpoints.empty() || pos_max > slot.server_cached_prompt.checkpoints.back().pos_max);
|
|
|
|
if (do_checkpoint) {
|
|
const int64_t t_start = ggml_time_us();
|
|
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
|
|
// make room for the new checkpoint, if needed
|
|
const auto & cur = slot.server_cached_prompt.checkpoints.front();
|
|
|
|
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
|
cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024);
|
|
|
|
slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
|
|
}
|
|
|
|
const size_t checkpoint_size = llama_state_seq_get_size(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
|
/*.pos_min = */ pos_min,
|
|
/*.pos_max = */ pos_max,
|
|
/*.pos_min_prompt = */ pos_min + slot.n_past_offset,
|
|
/*.pos_max_prompt = */ pos_max + slot.n_past_offset,
|
|
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
|
});
|
|
|
|
llama_state_seq_get_data(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
|
|
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB, took %.2f ms)\n",
|
|
(int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024,
|
|
(ggml_time_us() - t_start) / 1000.0);
|
|
}
|
|
return do_checkpoint;
|
|
}
|
|
|
|
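// tokenize and batch pending prompts: reuse cached tokens where possible, handle truncation / context shift, process image chunks, and fill the batch up to n_batch tokens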
void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) {
|
|
if (params_base.cont_batching || batch.n_tokens == 0) {
|
|
for (auto& slot : slots) {
|
|
// this slot still has a prompt to be processed
|
|
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
|
auto& prompt_tokens = slot.prompt_tokens;
|
|
|
|
// we haven't tokenized the prompt yet - do it now:
|
|
if (prompt_tokens.empty() || slot.n_prompt_tokens == 0) {
|
|
LOG_VERBOSE("tokenizing prompt", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task}
|
|
});
|
|
|
|
slot.t_start_process_prompt = ggml_time_us();
|
|
slot.t_start_generation = 0;
|
|
|
|
if (slot.infill) {
|
|
const bool add_bos = llama_should_add_bos_token(model);
|
|
bool suff_rm_leading_spc = true;
|
|
if (params_base.input_suffix.find_first_of(' ') == 0 && params_base.input_suffix.size() > 1) {
|
|
params_base.input_suffix.erase(0, 1);
|
|
suff_rm_leading_spc = false;
|
|
}
|
|
|
|
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
|
|
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
|
|
|
|
const int space_token = 29871; // TODO: this should not be hardcoded
|
|
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
|
|
suffix_tokens.erase(suffix_tokens.begin());
|
|
}
|
|
|
|
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
|
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
|
|
|
auto embd_inp = params_base.spm_infill ? suffix_tokens : prefix_tokens;
|
|
auto embd_end = params_base.spm_infill ? prefix_tokens : suffix_tokens;
|
|
if (add_bos) {
|
|
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
|
}
|
|
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
|
|
|
const llama_token middle_token = llama_token_middle(model);
|
|
if (middle_token >= 0) {
|
|
embd_inp.push_back(middle_token);
|
|
}
|
|
|
|
prompt_tokens = server_tokens(embd_inp, false);
|
|
}
|
|
else {
|
|
// prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
|
}
|
|
|
|
slot.n_past = 0;
|
|
slot.n_prompt_tokens = prompt_tokens.size();
|
|
|
|
LOG_VERBOSE("prompt tokenized", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", slot.n_ctx},
|
|
{"n_keep", slot.params.n_keep},
|
|
{"n_prompt_tokens", slot.n_prompt_tokens},
|
|
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
|
|
});
|
|
|
|
// empty prompt passed -> release the slot and send empty response
|
|
if (prompt_tokens.empty()) {
|
|
LOG_INFO("empty prompt - releasing slot", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task}
|
|
});
|
|
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
send_final_response(slot);
|
|
slot.release();
|
|
slot.print_timings();
|
|
continue;
|
|
}
|
|
|
|
if (slot.embedding) {
|
|
// this prompt is too large to process - discard it
|
|
if (slot.n_prompt_tokens > n_ubatch) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.release();
|
|
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
|
continue;
|
|
}
|
|
}
|
|
else {
|
|
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
|
// context shift for prompt processing
|
|
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
|
if (!params_base.ctx_shift) {
|
|
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER);
|
|
slot.release();
|
|
continue;
|
|
}
|
|
context_shift_prompt(ctx, slot);
|
|
slot.truncated = true;
|
|
LOG_VERBOSE("input truncated", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", slot.n_ctx},
|
|
{"n_keep", slot.params.n_keep},
|
|
{"n_left", slot.n_ctx - slot.params.n_keep},
|
|
{"n_prompt_tokens", slot.n_prompt_tokens},
|
|
{"prompt_tokens", prompt_tokens.detokenize(ctx, true)},
|
|
});
|
|
|
|
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
|
|
|
#ifndef NDEBUG
|
|
// debug
|
|
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
|
|
int32_t back = 1;
|
|
if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first + 20
|
|
&& prefix.second >= back && prefix.first >= back) {
|
|
LLAMA_LOG_INFO("After context shift :\n");
|
|
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50);
|
|
}
|
|
#endif
|
|
}
|
|
else {
|
|
slot.n_discarded_prompt = 0;
|
|
}
|
|
common_sampler_reset(slot.ctx_sampling);
|
|
|
|
if (!slot.params.cache_prompt) {
|
|
slot.n_past_se = 0;
|
|
slot.ga_i = 0;
|
|
}
|
|
else {
|
|
GGML_ASSERT(slot.ga_n == 1);
|
|
|
|
// reuse any previously computed tokens that are common with the new prompt
|
|
common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match
|
|
common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false);
|
|
auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match
|
|
LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t)slot.cache_tokens.size(), (int32_t)n_past0, (int32_t)prefix.first, (int32_t)prefix.second, (int32_t)prefix_nonexact.first, (int32_t)prefix_nonexact.second);
|
|
int32_t size_threshold = 20;
|
|
if (prefix.first + size_threshold < prefix_nonexact.first) {
|
|
// LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n");
|
|
prefix = prefix_nonexact;
|
|
}
|
|
slot.n_past = prefix.first;
|
|
slot.n_past_prompt = prefix.second;
|
|
slot.n_past_offset = slot.n_past_prompt - slot.n_past;
|
|
|
|
//if (slot.n_past != slot.n_past_prompt) {
|
|
// LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n");
|
|
//}
|
|
if ((slot.n_past + size_threshold < slot.cache_tokens.size()))
|
|
{
|
|
LLAMA_LOG_WARN("Common part does not match fully\n");
|
|
int32_t back = 4;
|
|
if (prefix.second >= back && prefix.first >= back) {
|
|
print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30);
|
|
}
|
|
}
|
|
|
|
// push the prompt into the sampling context (do not apply grammar)
|
|
for (int i = 0; i < slot.n_past; ++i) {
|
|
common_sampler_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
|
|
}
|
|
}
|
|
}
|
|
apply_checkpoint(slot);
|
|
if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) {
|
|
// we have to evaluate at least 1 token to generate logits.
|
|
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
|
|
{ "id_slot", slot.id },
|
|
{ "id_task", slot.id_task }
|
|
});
|
|
|
|
slot.n_past_prompt--;
|
|
slot.n_past--;
|
|
if (slot.ga_i > 0) {
|
|
slot.n_past_se--;
|
|
}
|
|
}
|
|
slot.n_prompt_tokens_cache = slot.n_past_prompt;
|
|
slot.n_prompt_tokens_processed = 0;
|
|
}
|
|
|
|
if (slot.embedding) {
|
|
// cannot fit the prompt in the current batch - will try next iter
|
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// check that we are in the right batch_type, if not defer the slot
|
|
const int32_t slot_type = slot.embedding ? 1 : 0;
|
|
if (batch_type == -1) {
|
|
batch_type = slot_type;
|
|
}
|
|
else if (batch_type != slot_type) {
|
|
continue;
|
|
}
|
|
|
|
// keep only the common part
|
|
// remove the non-common part from the cache
|
|
if (slot.n_past < 0)
|
|
{
|
|
slot.n_past = 0;
|
|
}
|
|
slot.cache_tokens.keep_first(slot.n_past);
|
|
int p0 = (int)system_tokens.size() + slot.cache_tokens.pos_next();
|
|
if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) {
|
|
// could not partially delete (likely using a non-Transformer model)
|
|
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
|
|
|
|
p0 = (int)system_tokens.size();
|
|
if (p0 != 0) {
|
|
// copy over the system prompt when there is one
|
|
llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1);
|
|
}
|
|
|
|
// there is no common part left (except for the system prompt)
|
|
slot.n_past = 0;
|
|
slot.n_past_se = 0;
|
|
slot.ga_i = 0;
|
|
// TODO: is the system prompt ever in the sampling context?
|
|
common_sampler_reset(slot.ctx_sampling);
|
|
}
|
|
|
|
LOG_INFO("kv cache rm [p0, end)", {
|
|
{ "id_slot", slot.id },
|
|
{ "id_task", slot.id_task },
|
|
{ "p0", p0 }
|
|
});
|
|
|
|
// check if we should process the image
|
|
if (slot.n_past_prompt < slot.n_prompt_tokens
|
|
&& slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) {
|
|
// process the image
|
|
size_t n_tokens_out = 0;
|
|
llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt
|
|
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out);
|
|
if (res != 0) {
|
|
LLAMA_LOG_ERROR("failed to process image, res = %d\n", res);
|
|
slot.release();
|
|
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
|
|
continue;
|
|
}
|
|
|
|
// add the image chunk to cache
|
|
{
|
|
const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt);
|
|
slot.cache_tokens.push_back(chunk.get()); // copy
|
|
}
|
|
|
|
slot.n_past += n_tokens_out;
|
|
slot.n_past_prompt += n_tokens_out;
|
|
slot.n_prompt_tokens_processed += n_tokens_out;
|
|
slot.image_just_processed = true; // do not checkpoint right after an image chunk
|
|
}
|
|
|
|
|
|
|
|
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
|
|
|
int32_t ga_i = slot.ga_i;
|
|
int32_t ga_n = slot.ga_n;
|
|
int32_t ga_w = slot.ga_w;
|
|
|
|
// add prompt tokens for processing in the current batch
|
|
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
|
|
while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
|
// get next token to process
|
|
llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt];
|
|
if (cur_tok == LLAMA_TOKEN_NULL) {
|
|
break; // end of text chunk
|
|
}
|
|
if (slot.ga_n != 1) {
|
|
while (slot_npast >= ga_i + ga_w) {
|
|
const int bd = (ga_w / ga_n) * (ga_n - 1);
|
|
slot_npast -= bd;
|
|
ga_i += ga_w / ga_n;
|
|
}
|
|
}
|
|
|
|
int p0 = system_tokens.size() + slot.cache_tokens.pos_next();
|
|
common_batch_add(batch, cur_tok, p0, { slot.id }, slot.need_embd());
|
|
|
|
slot.cache_tokens.push_back(cur_tok);
|
|
|
|
|
|
slot.n_prompt_tokens_processed++;
|
|
slot_npast++;
|
|
slot.n_past_prompt++;
|
|
slot.n_past++;
|
|
slot.image_just_processed = false;
|
|
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
|
|
slot.do_checkpoint = true;
|
|
break;
|
|
}
|
|
|
|
}
|
|
LOG_VERBOSE("prompt processing progress", {
|
|
{"id_slot", slot.id},
|
|
{"n_past", slot.n_past},
|
|
{"n_ctx", n_ctx},
|
|
{"n_tokens", batch.n_tokens},
|
|
{"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
|
|
});
|
|
|
|
// entire prompt has been processed - start decoding new tokens
|
|
if (slot.n_past_prompt == slot.n_prompt_tokens) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
GGML_ASSERT(batch.n_tokens > 0);
|
|
GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size());
|
|
common_sampler_reset(slot.ctx_sampling);
|
|
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
|
llama_token id = slot.prompt_tokens[i];
|
|
if (id != LLAMA_TOKEN_NULL) {
|
|
common_sampler_accept(slot.ctx_sampling, ctx, id, false);
|
|
}
|
|
}
|
|
|
|
// extract the logits only for the last token
|
|
batch.logits[batch.n_tokens - 1] = true;
|
|
|
|
slot.n_decoded = 0;
|
|
slot.i_batch = batch.n_tokens - 1;
|
|
|
|
LOG_VERBOSE("prompt done", {
|
|
{"id_slot", slot.id},
|
|
{"n_past", slot.n_past},
|
|
{"n_ctx", n_ctx},
|
|
{"n_tokens", batch.n_tokens},
|
|
});
|
|
}
|
|
}
|
|
|
|
if (batch.n_tokens >= n_batch) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void server_context::extend_context(const int32_t n_tokens) {
|
|
for (auto& slot : slots) {
|
|
if (slot.ga_n != 1) {
|
|
// context extension via Self-Extend
|
|
// TODO: simplify and/or abstract this
|
|
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
|
|
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
|
|
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
|
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
|
|
|
LOG_TEE("\n");
|
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
|
|
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
|
|
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
|
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
|
|
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
|
|
|
|
slot.n_past_se -= bd;
|
|
|
|
slot.ga_i += slot.ga_w / slot.ga_n;
|
|
|
|
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
|
|
}
|
|
|
|
slot.n_past_se += n_tokens;
|
|
}
|
|
}
|
|
}
|
|
|
|
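// sample against the drafted tokens, keep only the accepted prefix, roll the token/KV caches back to the last accepted position and emit the accepted tokens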
void server_context::speculative_decoding_accept() {
|
|
for (auto& slot : slots) {
|
|
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) {
|
|
continue;
|
|
}
|
|
|
|
size_t n_draft = slot.drafted.size();
|
|
|
|
apply_server_biases(slot);
|
|
|
|
// the accepted tokens from the speculation
|
|
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, slot.i_batch_dft, slot.drafted);
|
|
|
|
if (slot.has_mtp) {
|
|
llama_context * mtp_ctx = common_speculative_get_mtp_ctx(slot.spec);
|
|
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
|
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
if (!ids.empty()) {
|
|
const float* emb = llama_get_embeddings(ctx);
|
|
if (emb) {
|
|
slot.mtp_hidden_state.resize(ids.size() * n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb, ids.size() * n_embd * sizeof(float));
|
|
}
|
|
} else {
|
|
const float* emb0 = llama_get_embeddings_ith(ctx, 0);
|
|
if (emb0) {
|
|
slot.mtp_hidden_state.resize(n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb0, n_embd * sizeof(float));
|
|
}
|
|
}
|
|
|
|
llama_set_draft_input_hidden_state(mtp_target, slot.mtp_hidden_state.data());
|
|
|
|
int32_t n_past_base = slot.n_past - (slot.drafted.size() + 1);
|
|
mtp_accept_tokens(mtp_target, ids, n_past_base, slot.id);
|
|
}
|
|
|
|
slot.i_batch_dft.clear();
|
|
slot.drafted.clear();
|
|
|
|
slot.n_past += ids.size();
|
|
slot.n_decoded += ids.size();
|
|
const int64_t t_current = ggml_time_us();
|
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
|
|
|
// update how many tokens out of those tested were accepted
|
|
slot.n_draft_accepted += ids.size() - 1;
|
|
|
|
// inform the speculative decoding about the number of accepted tokens
|
|
common_speculative_accept(slot.spec, ids.size() - 1);
|
|
|
|
// rollback to the state before sampling the draft tokens
|
|
slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft);
|
|
|
|
// add accepted tokens to the prompt
|
|
slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 });
|
|
slot.sampled = ids.back(); // last accepted token
|
|
slot.n_past = slot.cache_tokens.n_tokens();
|
|
|
|
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
|
|
|
|
for (size_t i = 0; i < ids.size(); ++i) {
|
|
completion_token_output result;
|
|
|
|
result.tok = ids[i];
|
|
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
|
result.prob = 1.0f; // set later
|
|
|
|
if (slot.sparams.n_probs > 0) {
|
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, i);
|
|
}
|
|
|
|
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
|
if (!process_token(result, slot)) {
|
|
// release slot because of stop condition
|
|
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
slot.n_past++;
|
|
send_final_response(slot);
|
|
release_slot_after_final_response(slot);
|
|
break;
|
|
}
|
|
} else {
|
|
buffer_and_check_string_ban(slot, result);
|
|
if (slot.task == nullptr) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
|
|
|
update_allowlist_state(slot);
|
|
}
|
|
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
|
|
LOG_VERBOSE("speculative decoding result", {
|
|
{"id_slot", slot.id},
|
|
{"accepted", (int)ids.size() - 1},
|
|
{"total", (int)slot.drafted.size()},
|
|
{"new_n_past", slot.n_past}
|
|
});
|
|
}
|
|
}
|
|
|
|
|
|
bool server_context::accept_special_token(const server_slot& slot, const llama_token token) {
|
|
return params_base.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
|
|
}
|
|
|
|
void server_context::release_slot_after_final_response(server_slot & slot) {
|
|
slot.print_timings();
|
|
if (params_base.do_checkpoint) {
|
|
create_checkpoint(slot);
|
|
}
|
|
slot.release();
|
|
slot.released = true;
|
|
metrics.on_prediction(slot);
|
|
}
|
|
|
|
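// flush up to n buffered tokens through process_token(), dropping their positional bans; when a stop condition is reached, send the final response and release the slot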
void server_context::send_token_results(completion_token_outputs& results, server_slot& slot, int32_t n) {
|
|
int count = 0;
|
|
bool released = false;
|
|
|
|
int32_t start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
|
|
|
for (auto& it : results) {
|
|
bool has_next = process_token(it, slot);
|
|
|
|
// Clean up positional bans for the token we just confirmed/sent
|
|
slot.positional_bans.erase(start_pos + count);
|
|
|
|
count++;
|
|
if (!has_next) {
|
|
if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
|
continue;
|
|
}
|
|
slot.i_batch = batch.n_tokens;
|
|
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
slot.n_past++;
|
|
send_final_response(slot);
|
|
release_slot_after_final_response(slot);
|
|
released = true;
|
|
break;
|
|
}
|
|
if (n > 0 && count >= n) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
|
|
slot.i_batch = batch.n_tokens;
|
|
common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true);
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
slot.n_past++;
|
|
send_final_response(slot);
|
|
release_slot_after_final_response(slot);
|
|
}
|
|
|
|
if (count > 0) {
|
|
slot.sampled = results[results.size()-1].tok;
|
|
results.erase(results.begin(), results.begin() + count);
|
|
}
|
|
|
|
}
|
|
|
|
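// scan the buffered token text for banned phrases / regexes and return the absolute position of the first token that starts a match, or -1 if there is none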
inline int32_t check_ban_phrase(server_slot& slot) {
|
|
if (slot.token_buffer.empty()) return -1;
|
|
|
|
std::string string_buffer;
|
|
std::vector<size_t> token_offsets;
|
|
|
|
for (const auto& it : slot.token_buffer) {
|
|
token_offsets.push_back(string_buffer.size());
|
|
string_buffer += it.text_to_send;
|
|
}
|
|
|
|
size_t best_start = std::string::npos;
|
|
bool found = false;
|
|
std::string string_buffer_lower = string_lower(string_buffer);
|
|
|
|
// 1. Check strings
|
|
for (const auto& phrase : slot.ban_phrases) {
|
|
size_t start = string_buffer_lower.find(phrase);
|
|
if (start != std::string::npos) {
|
|
if (start < best_start) {
|
|
best_start = start;
|
|
found = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2. Check regex
|
|
for (const auto& pattern : slot.ban_regex) {
|
|
try {
|
|
std::regex re(pattern);
|
|
std::smatch match;
|
|
if (std::regex_search(string_buffer, match, re)) {
|
|
if (match.position() < best_start) {
|
|
best_start = match.position();
|
|
found = true;
|
|
}
|
|
}
|
|
} catch (...) { continue; }
|
|
}
|
|
|
|
// 3. Check regex case insensitive
|
|
for (const auto& pattern : slot.ban_regex_ci) {
|
|
try {
|
|
std::regex re(pattern, std::regex_constants::icase);
|
|
std::smatch match;
|
|
if (std::regex_search(string_buffer, match, re)) {
|
|
if (match.position() < best_start) {
|
|
best_start = match.position();
|
|
found = true;
|
|
}
|
|
}
|
|
} catch (...) { continue; }
|
|
}
|
|
|
|
if (found) {
|
|
int32_t token_idx = -1;
|
|
for (size_t i = 0; i < token_offsets.size(); ++i) {
|
|
size_t len = (i == token_offsets.size() - 1)
|
|
? string_buffer.size() - token_offsets[i]
|
|
: token_offsets[i+1] - token_offsets[i];
|
|
|
|
if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) {
|
|
token_idx = (int32_t)i;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (token_idx != -1) {
|
|
int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx;
|
|
return abs_pos;
|
|
}
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
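// roll generation back to just before a banned match: record positional logit bans for the offending tokens, truncate the token buffer and KV cache, and resume sampling from the ban position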
inline void rewind_context(server_slot& slot, int32_t ban_pos) {
|
|
slot.rewind_count++;
|
|
|
|
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
|
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
|
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
|
|
|
if (slot.banned_n != 0) {
|
|
int32_t n = 0;
|
|
for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
|
|
llama_token banned_tok = result->tok;
|
|
|
|
if (n == 0) {
|
|
LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
|
|
ban_pos, banned_tok, result->text_to_send.c_str());
|
|
}
|
|
|
|
slot.positional_bans[ban_pos].insert(banned_tok);
|
|
n++;
|
|
if (slot.banned_n > 0 && n == slot.banned_n) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
int32_t n_rewind_total = (slot.n_past + 1) - ban_pos;
|
|
|
|
size_t n_keep_cache = 0;
|
|
if (ban_pos > 0) {
|
|
n_keep_cache = (size_t)(ban_pos - 1);
|
|
}
|
|
|
|
if (n_keep_cache > slot.cache_tokens.size()) {
|
|
n_keep_cache = slot.cache_tokens.size();
|
|
}
|
|
|
|
if (n_keep_cache < slot.cache_tokens.size()) {
|
|
slot.sampled = slot.cache_tokens[n_keep_cache];
|
|
} else {
|
|
slot.sampled = 0;
|
|
}
|
|
|
|
// Truncate cache
|
|
slot.cache_tokens.keep_first(n_keep_cache);
|
|
slot.n_past = slot.cache_tokens.n_tokens();
|
|
|
|
// Remove from KV cache
|
|
llama_kv_cache_seq_rm(slot.ctx, slot.id, slot.n_past, -1);
|
|
|
|
// Truncate buffer
|
|
slot.token_buffer.resize(n_keep_buffer);
|
|
|
|
// Adjust decoded count
|
|
if (slot.saturate_predict) {
|
|
slot.n_decoded -= n_rewind_total;
|
|
if (slot.n_decoded < 0) slot.n_decoded = 0;
|
|
}
|
|
}
|
|
|
|
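// buffer the newly sampled token and either rewind (a banned phrase was found), flush buffered tokens (buffer full or generation stopped), or keep buffering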
void server_context::buffer_and_check_string_ban(server_slot & slot, completion_token_output & result) {
|
|
slot.token_buffer.push_back(result);
|
|
|
|
bool next_token = has_next_token(result, slot);
|
|
// If buffer full or generation stopped, we might send tokens
|
|
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
|
|
|
|
int32_t ban_pos = -1;
|
|
bool sent_results = false;
|
|
|
|
// Always reset logit bias to base before checking bans
|
|
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
|
|
|
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
|
|
ban_pos = check_ban_phrase(slot);
|
|
}
|
|
|
|
bool allow_rewind = true;
|
|
|
|
if (ban_pos >= 0) {
|
|
if (slot.rewind_count_max == -1) {
|
|
// Automatic / Heuristic logic
|
|
// Account for strings + regex + regex_ci
|
|
size_t total_bans = slot.ban_phrases.size() + slot.ban_regex.size() + slot.ban_regex_ci.size();
|
|
|
|
// Heuristic: Allow if under 20 OR under 2 * total_bans
|
|
// Conversely: Stop if >= 20 AND > 2 * total_bans
|
|
if (slot.rewind_count >= 20 && slot.rewind_count > 2 * total_bans) {
|
|
allow_rewind = false;
|
|
}
|
|
}
|
|
else if (slot.rewind_count_max > 0) {
|
|
// Strict limit logic
|
|
if (slot.rewind_count >= slot.rewind_count_max) {
|
|
allow_rewind = false;
|
|
}
|
|
}
|
|
// If slot.rewind_count_max == 0, allow_rewind remains true (Infinite)
|
|
}
|
|
|
|
if (ban_pos >= 0 && allow_rewind) {
|
|
rewind_context(slot, ban_pos);
|
|
slot.rewind_status = true;
|
|
}
|
|
else if (buffer_full || !next_token) {
|
|
slot.rewind_status = false;
|
|
slot.rewind_count = 0;
|
|
|
|
if (!next_token) {
|
|
// send all remaining tokens
|
|
send_token_results(slot.token_buffer, slot);
|
|
}
|
|
else {
|
|
// send 1 token from the front (FIFO)
|
|
send_token_results(slot.token_buffer, slot, 1);
|
|
}
|
|
}
|
|
else {
|
|
// buffer the result, wait for more tokens to validate string
|
|
slot.sampled = result.tok;
|
|
}
|
|
}
|
|
|
|
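// advance the allow-list cursor past any required keywords that have already appeared in the generated text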
void server_context::update_allowlist_state(server_slot& slot) {
|
|
const auto& kws = slot.allow_kws;
|
|
auto& idx = slot.allow_idx;
|
|
if ((slot.allow_kw_delay > slot.n_decoded) || (idx >= kws.size())) {
|
|
return;
|
|
}
|
|
|
|
// search for keyword
|
|
auto kw = kws[idx];
|
|
auto pos = slot.generated_text.find(kw, std::max(0, slot.last_gentxt_size - (int32_t)kw.length() + 1));
|
|
while (pos != std::string::npos) {
|
|
if (++idx >= kws.size()) {
|
|
break;
|
|
}
|
|
kw = kws[idx];
|
|
pos = slot.generated_text.find(kw, pos + 1);
|
|
}
|
|
}
|
|
|
|
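// decode the prepared batch in chunks of at most n_batch tokens, halving the chunk size when the KV cache runs out of space, then sample and emit new tokens for each active slot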
void server_context::process_batch_tokens(int32_t & n_batch) {
|
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
|
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
|
extend_context(n_tokens);
|
|
|
|
llama_batch batch_view = {
|
|
n_tokens,
|
|
batch.token + i,
|
|
nullptr,
|
|
batch.pos + i,
|
|
batch.n_seq_id + i,
|
|
batch.seq_id + i,
|
|
batch.logits + i,
|
|
0, 0, 0, // unused
|
|
};
|
|
|
|
const int ret = llama_decode(ctx, batch_view);
|
|
if (ret != 0) {
|
|
if (n_batch == 1 || ret < 0) {
|
|
int user_cancel = -3;
|
|
if (ret == user_cancel) {
|
|
LLAMA_LOG_INFO("Decode process is cancelled by user.\n");
|
|
}
|
|
else {
|
|
// if you get here, it means the KV cache is full - try increasing it via the context size
|
|
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
|
|
{"i", i},
|
|
{"n_batch", ret},
|
|
{"ret", ret},
|
|
});
|
|
}
|
|
|
|
for (auto& slot : slots) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.release();
|
|
if (ret != user_cancel) {
|
|
LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size());
|
|
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
|
}
|
|
}
|
|
break; // break loop of n_batch
|
|
}
|
|
// retry with half the batch size to try to find a free slot in the KV cache
|
|
n_batch /= 2;
|
|
i -= n_batch;
|
|
|
|
LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
|
|
{"i", i},
|
|
{"n_batch", n_batch},
|
|
{"ret", ret},
|
|
});
|
|
|
|
continue; // continue loop of n_batch
|
|
}
|
|
|
|
bool mtp_warmup_needed = false;
|
|
std::vector<float> batch_mtp_hidden_state;
|
|
if (params_base.has_mtp) {
|
|
for (auto& slot : slots) {
|
|
if ((slot.state == SLOT_STATE_PROCESSING && slot.n_decoded == 0) ||
|
|
(slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)) {
|
|
bool has_tokens_for_slot = (batch_view.n_tokens > 0 && batch_view.n_seq_id[0] > 0 && batch_view.seq_id[0][0] == slot.id);
|
|
if (has_tokens_for_slot) {
|
|
mtp_warmup_needed = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
if (mtp_warmup_needed) {
|
|
const float* emb = llama_get_embeddings(ctx);
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
const int n_toks = batch_view.n_tokens;
|
|
if (emb) {
|
|
batch_mtp_hidden_state.resize(n_toks * n_embd);
|
|
memcpy(batch_mtp_hidden_state.data(), emb, n_toks * n_embd * sizeof(float));
|
|
}
|
|
}
|
|
}
|
|
|
|
for (auto& slot : slots) {
|
|
bool is_active_slot = (slot.state == SLOT_STATE_PROCESSING);
|
|
|
|
if (!is_active_slot || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
|
|
// save checkpoint during prompt processing
|
|
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
|
if (slot.do_checkpoint) {
|
|
create_checkpoint(slot);
|
|
} else {
|
|
create_checkpoint_at_interval(slot, params_base);
|
|
}
|
|
}
|
|
continue; // continue loop of slots
|
|
}
|
|
|
|
// prompt evaluated for embedding
|
|
if (slot.embedding) {
|
|
send_embedding(slot, batch_view);
|
|
slot.release();
|
|
slot.i_batch = -1;
|
|
continue; // continue loop of slots
|
|
}
|
|
|
|
if (slot.n_decoded == 0 && slot.can_speculate()) {
|
|
common_speculative_begin(slot.spec, slot.cache_tokens.get_text_tokens());
|
|
}
|
|
|
|
if (slot.i_batch_dft.size() > 0) {
|
|
continue; // sample using speculative decoding
|
|
}
|
|
|
|
// RESTORE AND APPLY POSITIONAL BANS
|
|
slot.ctx_sampling->params.logit_bias = slot.logit_bias;
|
|
auto ban_it = slot.positional_bans.find(slot.n_past);
|
|
if (ban_it != slot.positional_bans.end()) {
|
|
for (llama_token tok : ban_it->second) {
|
|
slot.ctx_sampling->params.logit_bias[tok] += slot.ban_phrases_bias;
|
|
}
|
|
}
|
|
|
|
completion_token_output result;
|
|
const int tok_idx = slot.i_batch - i;
|
|
|
|
if (params_base.has_mtp && slot.n_decoded == 0) {
|
|
const float* emb_i = llama_get_embeddings_ith(ctx, tok_idx);
|
|
if (emb_i) {
|
|
const int n_embd = llama_model_n_embd(llama_get_model(ctx));
|
|
slot.mtp_hidden_state.resize(n_embd);
|
|
memcpy(slot.mtp_hidden_state.data(), emb_i, n_embd * sizeof(float));
|
|
}
|
|
}
|
|
|
|
apply_server_biases(slot);
|
|
|
|
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, tok_idx);
|
|
|
|
common_sampler_accept(slot.ctx_sampling, ctx, id, true);
|
|
|
|
slot.n_decoded += 1;
|
|
const int64_t t_current = ggml_time_us();
|
|
|
|
if (slot.n_decoded == 1) {
|
|
slot.t_start_generation = ggml_time_us();
|
|
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
|
metrics.on_prompt_eval(slot);
|
|
// create checkpoint after prompt processing ends
|
|
if (params_base.ctx_checkpoints_tolerance <= 0 && params_base.do_checkpoint) {
|
|
create_checkpoint(slot);
|
|
}
|
|
}
|
|
|
|
// create checkpoint during generation
|
|
if (slot.n_decoded > 1) {
|
|
create_checkpoint_at_interval(slot, params_base);
|
|
}
|
|
|
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
|
|
|
result.tok = id;
|
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
|
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
|
if (slot.sparams.n_probs > 0) {
|
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
|
|
}
|
|
|
|
// no ban string for recurrent/hybrid model
|
|
if (slot.n_buffer == 0 || !params_base.can_ban_phrases) {
|
|
slot.token_buffer = { result };
|
|
send_token_results(slot.token_buffer, slot);
|
|
} else {
|
|
// buffer the result and check string ban.
|
|
// if ban, we need to go back, apply logit bias and regenerate
|
|
buffer_and_check_string_ban(slot, result);
|
|
}
|
|
|
|
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
|
|
|
update_allowlist_state(slot);
|
|
|
|
slot.i_batch = -1;
|
|
}
|
|
if (mtp_warmup_needed && !batch_mtp_hidden_state.empty()) {
|
|
llama_context * mtp_ctx = nullptr;
|
|
for (auto & slot : slots) {
|
|
if (slot.spec && slot.has_mtp) {
|
|
llama_context * mc = common_speculative_get_mtp_ctx(slot.spec);
|
|
if (mc) { mtp_ctx = mc; break; }
|
|
}
|
|
}
|
|
llama_context * mtp_target = mtp_ctx ? mtp_ctx : ctx;
|
|
llama_set_draft_input_hidden_state(mtp_target, batch_mtp_hidden_state.data());
|
|
mtp_update_kv_cache(mtp_target, batch_view, true);
|
|
}
|
|
|
|
// speculative decoding - main model sample and accept
|
|
speculative_decoding_accept();
|
|
}
|
|
}
|
|
|
|
void server_context::update_slots() {
|
|
if (system_need_update) {
|
|
system_prompt_update();
|
|
}
|
|
// release slots
|
|
release_slots();
|
|
|
|
// check if all slots are idle
|
|
if (slots_idle()) {
|
|
return;
|
|
}
|
|
|
|
{
|
|
LOG_VERBOSE("posting NEXT_RESPONSE", {});
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
|
|
task.id_target = -1;
|
|
|
|
queue_tasks.post(std::move(task));
|
|
}
|
|
|
|
// apply context-shift if needed
|
|
// TODO: simplify and improve
|
|
context_shift();
|
|
|
|
// start populating the batch for this iteration
|
|
common_batch_clear(batch);
|
|
|
|
// first, add sampled tokens from any ongoing sequences
|
|
add_sampled_tokens(); // Prepare batch for inference
|
|
|
|
// process in chunks of params.n_batch
|
|
int32_t n_batch = llama_n_batch(ctx);
|
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
|
|
|
// track if this is an embedding or non-embedding batch
|
|
// if we've added sampled tokens above, we are in non-embedding mode
|
|
// -1: none, 0: non-embedding, 1: embedding
|
|
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
|
|
|
// next, batch any pending prompts without exceeding n_batch
|
|
batch_pending_prompt(n_ubatch, n_batch, batch_type); // Prepare batch for prompt process
|
|
|
|
if (batch.n_tokens == 0) {
|
|
LOG_VERBOSE("no tokens to decode", {});
|
|
return;
|
|
}
|
|
|
|
LOG_VERBOSE("decoding batch", {
|
|
{"n_tokens", batch.n_tokens},
|
|
});
|
|
|
|
// make sure we're in the right embedding mode
|
|
llama_set_embeddings(ctx, batch_type == 1);
|
|
|
|
// process the created batch of tokens
|
|
process_batch_tokens(n_batch); // Decode with batch
|
|
|
|
LOG_VERBOSE("run slots completed", {});
|
|
}
|
|
|
|
json server_context::model_meta() const {
|
|
return json{
|
|
{"vocab_type", llama_vocab_type(llama_model_get_vocab(model))},
|
|
{"n_vocab", llama_n_vocab(model)},
|
|
{"n_ctx_train", llama_n_ctx_train(model)},
|
|
{"n_embd", llama_model_n_embd(model)},
|
|
{"n_params", llama_model_n_params(model)},
|
|
{"size", llama_model_size(model)},
|
|
};
|
|
}
|