diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 8dc218b1..4d3f462b 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -71,6 +71,8 @@ add_library(${TARGET} STATIC json-schema-to-grammar.cpp train.h train.cpp + log.cpp + log.h ngram-cache.h ngram-cache.cpp speculative.cpp diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 748e0a22..0ab42738 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -13,6 +13,120 @@ using json = nlohmann::ordered_json; +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, + const common_regex & prefix, + size_t rstrip_prefix = 0) { + static const std::vector> args_paths = { { "arguments" } }; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json{ + { "code", code + builder.healing_marker() } + }) + .dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); + } + } else { + arguments = (json{ + { "code", code } + }) + .dump(); + } + return arguments; +} + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. + */ +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = + nullptr) { + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto start_pos = builder.pos(); + auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : + function_regex ? builder.try_find_regex(*function_regex, from) : + std::nullopt; + + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. 
+ from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; + + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); + } + break; + } + if (block_close) { + builder.consume_regex(*block_close); + } + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); + }; + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); + } + } else { + parse_tool_calls(); + } +} + common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) : input_(input), is_partial_(is_partial), syntax_(syntax) { @@ -78,6 +192,38 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) { } return true; } + +bool common_chat_msg_parser::add_tool_call_short_form(const json& tool_call) { + if (!tool_call.is_object() || tool_call.size() != 1) { + return false; + } + + // Get the tool name (the single key in the object) + auto it = tool_call.begin(); + std::string name = it.key(); + + if (name.empty()) { + return false; + } + + // Get the arguments (the nested object) + const json& args_json = it.value(); + std::string arguments = ""; + + if (args_json.is_object()) { + arguments = args_json.dump(); + } + else if (args_json.is_string()) { + arguments = args_json; + } + else if (!args_json.is_null()) { + // For other types, convert to string representation + arguments = args_json.dump(); + } + + return add_tool_call(name, "", arguments); +} + void common_chat_msg_parser::finish() { if (!is_partial_ && pos_ != input_.size()) { throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); @@ -503,3 +649,857 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +/** + * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below + * to reduce incremental compile time for parser changes. 
+ */ +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); + } +} + +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_magistral(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("[THINK]", "[/THINK]"); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + builder.try_parse_reasoning("", ""); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex 
close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal(""); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + +static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "", ""); +} + +static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + 
form.last_tool_end = "}"; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form); +} + +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; + static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); + + static const common_regex start_regex("<\\|start\\|>assistant"); + static const common_regex analysis_regex("<\\|channel\\|>analysis"); + static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); + static const common_regex preamble_regex("<\\|channel\\|>commentary"); + static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); + static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); + + auto consume_end = [&](bool include_end = false) { + if (auto res = builder.try_find_literal("<|end|>")) { + return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); + } + return builder.consume_rest(); + }; + + auto handle_tool_call = [&](const std::string & name) { + if (auto args = builder.try_consume_json_with_dumped_args({{}})) { + if (builder.syntax().parse_tool_calls) { + if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + }; + + auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { + auto match = regex.search(input, 0, true); + if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { + return match; + } + return std::nullopt; + }; + + do { + auto header_start_pos = builder.pos(); + auto content_start = builder.try_find_literal("<|message|>"); + if (!content_start) { + throw common_chat_msg_partial_exception("incomplete header"); + } + + auto header = content_start->prelude; + + if (auto match = regex_match(tool_call1_regex, header)) { + auto group = match->groups[1]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (auto match = regex_match(tool_call2_regex, header)) { + auto group = match->groups[2]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (regex_match(analysis_regex, header)) { + builder.move_to(header_start_pos); + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { + builder.add_content(consume_end(true)); + } else { + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); + } + continue; + } + + if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { + builder.add_content(consume_end()); + continue; + } + + // Possibly a malformed message, attempt to recover by rolling + // back to pick up the next 
<|start|> + LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); + builder.move_to(header_start_pos); + } while (builder.try_find_regex(start_regex, std::string::npos, false)); + + auto remaining = builder.consume_rest(); + if (!remaining.empty()) { + LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); + } +} + +static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.tool_sep = */ "", + /* form.key_start = */ "", + /* form.key_val_sep = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + /* form.key_val_sep2 = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); +} + +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); + + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); +} + +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } +} + +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" 
+ "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); + + while (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; + + const auto & open_tag = res->groups[2]; + std::string close_tag; + + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } + } + } + + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + static const common_regex start_think_regex(regex_escape("")); + static const common_regex end_think_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "groups[0].begin); + builder.try_find_regex(end_think_regex, std::string::npos, false); + // Restore position for try_parse_reasoning() + builder.move_to(res->groups[0].begin); + } + builder.try_parse_reasoning("", ""); + + // Parse response tags + static const common_regex start_response_regex(regex_escape("")); + static const common_regex end_response_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { + if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.try_consume_literal("")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + builder.add_tool_calls(tool_calls_data.json); + } else { + throw common_chat_msg_partial_exception("Incomplete tool 
call"); + } + } + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + +static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + 
builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + common_chat_parse_content_only(builder); + break; + case COMMON_CHAT_FORMAT_GENERIC: + common_chat_parse_generic(builder); + break; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + common_chat_parse_mistral_nemo(builder); + break; + case COMMON_CHAT_FORMAT_MAGISTRAL: + common_chat_parse_magistral(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X: + common_chat_parse_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + common_chat_parse_deepseek_r1(builder); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + common_chat_parse_functionary_v3_2(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + common_chat_parse_hermes_2_pro(builder); + break; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + common_chat_parse_firefunction_v2(builder); + break; + case COMMON_CHAT_FORMAT_COMMAND_R7B: + common_chat_parse_command_r7b(builder); + break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; + case COMMON_CHAT_FORMAT_SEED_OSS: + common_chat_parse_seed_oss(builder); + break; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: + common_chat_parse_nemotron_v2(builder); + break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; + case COMMON_CHAT_FORMAT_MINIMAX_M2: + common_chat_parse_minimax_m2(builder); + break; + case COMMON_CHAT_FORMAT_GLM_4_5: + common_chat_parse_glm_4_5(builder); + break; + case COMMON_CHAT_FORMAT_KIMI_K2: + common_chat_parse_kimi_k2(builder); + break; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: + common_chat_parse_qwen3_coder_xml(builder); + break; + case COMMON_CHAT_FORMAT_APRIEL_1_5: + common_chat_parse_apriel_1_5(builder); + break; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: + common_chat_parse_xiaomi_mimo(builder); + break; + default: + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); + } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git 
a/common/chat-parser.h b/common/chat-parser.h index 824982b4..cb70d16e 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -63,6 +63,9 @@ class common_chat_msg_parser { // Adds an array of tool calls using their "name", "id" and "arguments" fields. bool add_tool_calls(const nlohmann::ordered_json & arr); + // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } } + bool add_tool_call_short_form(const nlohmann::ordered_json& tool_call); + void finish(); bool consume_spaces(); diff --git a/common/chat.cpp b/common/chat.cpp index 0ae044be..21aa524e 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -627,6 +627,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; + case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral"; case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; @@ -638,6 +639,10 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; + case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; + case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2"; case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5"; case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2"; @@ -676,114 +681,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string& for throw std::runtime_error("Unknown reasoning format: " + format); } -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json {{"code", code + builder.healing_marker()}}).dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json {{"code", code}}).dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = nullptr) { - - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first - ? builder.try_consume_regex(*function_regex_start_only) - : function_regex - ? 
builder.try_find_regex(*function_regex, from) - : std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. - from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { - static const std::vector> args_paths = {{"arguments"}}; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { @@ -915,37 +812,6 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? 
response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -991,16 +857,6 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; } -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1081,39 +937,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); @@ -1212,63 +1035,6 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te }); return data; } -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1408,88 +1174,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha return data; } -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - 
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - - static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != 
COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1532,20 +1216,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1578,23 +1248,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c return data; } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "", ""); -} - static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1693,25 +1327,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t return data; } -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1744,24 +1359,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_ return data; } -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; auto prompt = apply(tmpl, inputs); @@ -1892,93 +1489,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - 
static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2059,21 +1569,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp return data; } -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - 
/* form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG("%s\n", __func__); common_chat_params data; @@ -2119,14 +1614,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... @@ -2177,34 +1664,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt @@ -2264,31 +1723,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con // TODO: if (has_raw_python) return data; } -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. 
- static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2405,83 +1839,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat return data; } -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2564,53 +1921,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - 
builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2802,7 +2112,7 @@ static common_chat_params common_chat_templates_apply_legacy( const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { - int alloc_size = 0; + size_t alloc_size = 0; std::vector chat; std::vector contents; @@ -2824,7 +2134,8 @@ static common_chat_params common_chat_templates_apply_legacy( const auto & msg = inputs.messages[i]; const auto & content = contents[i]; chat.push_back({msg.role.c_str(), content.c_str()}); - alloc_size += (msg.role.size() + content.size()) * 1.25; + size_t msg_size = msg.role.size() + content.size(); + alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops } std::vector buf(alloc_size); @@ -2846,6 +2157,11 @@ static common_chat_params common_chat_templates_apply_legacy( res = llama_chat_apply_template(nullptr, src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); } + // for safety, we check the result again + if (res < 0 || (size_t) res > buf.size()) { + throw std::runtime_error("failed to apply chat template, try using --jinja"); + } + common_chat_params params; params.prompt = std::string(buf.data(), res); if (!inputs.json_schema.empty()) { @@ -2865,97 +2181,3 @@ common_chat_params common_chat_templates_apply( ? 
common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: - common_chat_parse_qwen3_coder_xml(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } - return msg; -} diff --git a/common/chat.h b/common/chat.h index cdea627a..7c234019 100644 --- a/common/chat.h +++ b/common/chat.h @@ -101,6 +101,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, COMMON_CHAT_FORMAT_GENERIC, COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_MAGISTRAL, COMMON_CHAT_FORMAT_LLAMA_3_X, COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, COMMON_CHAT_FORMAT_DEEPSEEK_R1, @@ -112,6 +113,10 @@ enum 
common_chat_format { COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_GRANITE, COMMON_CHAT_FORMAT_GPT_OSS, + COMMON_CHAT_FORMAT_SEED_OSS, + COMMON_CHAT_FORMAT_NEMOTRON_V2, + COMMON_CHAT_FORMAT_APERTUS, + COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, COMMON_CHAT_FORMAT_GLM_4_5, COMMON_CHAT_FORMAT_MINIMAX_M2, COMMON_CHAT_FORMAT_KIMI_K2, diff --git a/common/common.cpp b/common/common.cpp index e695375c..44290888 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2726,11 +2726,29 @@ bool fs_validate_filename(const std::string & filename) { return true; } +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif + // returns true if successful, false otherwise bool fs_create_directory_with_parents(const std::string & path) { #ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); + std::wstring wpath = utf8_to_wstring(path); // if the path already exists, check whether it's a directory const DWORD attributes = GetFileAttributesW(wpath.c_str()); @@ -3586,175 +3604,6 @@ bool llama_should_add_bos_token(const llama_model * model) { return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } -// -// Chat template utils -// -// -//bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) { -// if (use_jinja) { -// try { -// auto chat_template = common_chat_template(tmpl, "", ""); -// common_chat_inputs inputs; -// inputs.messages = json::array({ { -// {"role", "user"}, -// {"content", "test"}, -// } }); -// common_chat_params_init(chat_template, inputs); -// return true; -// } -// catch (const std::exception& e) { -// fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what()); -// return false; -// } -// } -// llama_chat_message chat[] = { {"user", "test"} }; -// const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); -// return res >= 0; -//} - -//std::string llama_chat_apply_template(const struct llama_model * model, -// const common_chat_template& tmpl, -// const std::vector & msgs, -// bool add_ass, -// bool use_jinja) { -// if (use_jinja) { -// auto messages = json::array(); -// for (const auto& msg : msgs) { -// messages.push_back({ {"role", msg.role}, {"content", msg.content} }); -// } -// common_chat_inputs inputs; -// inputs.messages = messages; -// inputs.add_generation_prompt = add_ass; -// return common_chat_params_init(tmpl, inputs).prompt; -// } -// int alloc_size = 0; -// std::vector chat; -// for (auto & msg : msgs) { -// chat.push_back({msg.role.c_str(), msg.content.c_str()}); -// alloc_size += (msg.role.size() + msg.content.size()) * 1.25; -// } -// -// std::vector buf(alloc_size); -// -// // run the first time to get the total output length -// int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); -// // error: chat template is not supported -// if (res < 0) { -// // if the custom "tmpl" is not supported, we throw an error -// // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() -// throw std::runtime_error("this 
custom template is not supported"); -// } -// -// // if it turns out that our buffer is too small, we resize it -// if ((size_t)res > buf.size()) { -// buf.resize(res); -// res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); -// } -// -// std::string formatted_chat(buf.data(), res); -// return formatted_chat; -//} -//// -//std::string llama_chat_format_single(const struct llama_model * model, -// const common_chat_template& tmpl, -// const std::vector & past_msg, -// const common_chat_msg & new_msg, -// bool add_ass, -// bool use_jinja) { -// std::ostringstream ss; -// auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja); -// std::vector chat_new(past_msg); -// // if the past_msg ends with a newline, we must preserve it in the formatted version -// if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { -// ss << "\n"; -// }; -// // format chat with new_msg -// chat_new.push_back(new_msg); -// auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja); -// // get the diff part -// ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); -// return ss.str(); -//} - -//std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) { -// std::vector msgs = { -// {"system", "You are a helpful assistant", {}}, -// {"user", "Hello", {}}, -// {"assistant", "Hi there", {}}, -// {"user", "How are you?", {}}, -// }; -// return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja); -//} -// -//#define CHATML_TEMPLATE_SRC \ -// "{%- for message in messages -%}\n" \ -// " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ -// "{%- endfor -%}\n" \ -// "{%- if add_generation_prompt -%}\n" \ -// " {{- '<|im_start|>assistant\n' -}}\n" \ -// "{%- endif -%}" -// -//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override) -//{ -// std::string default_template_src; -// std::string template_tool_use_src; -// bool has_explicit_template = !chat_template_override.empty(); -// if (chat_template_override.empty()) { -// auto str = llama_model_chat_template(model, /* name */ nullptr); -// if (str) { -// default_template_src = str; -// has_explicit_template = true; -// } -// str = llama_model_chat_template(model, /* name */ "tool_use"); -// if (str) { -// template_tool_use_src = str; -// has_explicit_template = true; -// } -// } -// else { -// default_template_src = chat_template_override; -// } -// if (default_template_src.empty() || default_template_src == "chatml") { -// if (!template_tool_use_src.empty()) { -// default_template_src = template_tool_use_src; -// } -// else { -// default_template_src = CHATML_TEMPLATE_SRC; -// } -// } -// auto vocab = llama_model_get_vocab(model); -// const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) { -// if (token == LLAMA_TOKEN_NULL) { -// if (default_template_src.find(jinja_variable_name) != std::string::npos -// || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { -// fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); -// } -// return std::string(); -// } -// else { -// return llama_token_to_piece(model, token, true); -// } -// }; -// auto token_bos = 
get_token(llama_token_bos_impl(*vocab), "BOS", "bos_token");
-// auto token_eos = get_token(llama_token_eos_impl(*vocab), "EOS", "eos_token");
-// try {
-// return {
-// has_explicit_template,
-// std::make_unique(default_template_src, token_bos, token_eos),
-// template_tool_use_src.empty()
-// ? nullptr
-// : std::make_unique(template_tool_use_src, token_bos, token_eos),
-// };
-// }
-// catch (const std::exception& e) {
-// LOG("%s: failed to parse chat template: %s\n", __func__, e.what());
-// return {
-// has_explicit_template,
-// std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos),
-// nullptr,
-// };
-// }
-//}
 //
 // KV cache utils
diff --git a/common/log.cpp b/common/log.cpp
new file mode 100644
index 00000000..b6c9ff79
--- /dev/null
+++ b/common/log.cpp
@@ -0,0 +1,464 @@
+#include "log.h"
+
+#include <chrono>
+#include <condition_variable>
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <mutex>
+#include <sstream>
+#include <thread>
+#include <vector>
+
+#if defined(_WIN32)
+# include <io.h>
+# include <windows.h>
+# define isatty _isatty
+# define fileno _fileno
+#else
+# include <unistd.h>
+#endif // defined(_WIN32)
+
+int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
+
+void common_log_set_verbosity_thold(int verbosity) {
+    common_log_verbosity_thold = verbosity;
+}
+
+// Auto-detect if colors should be enabled based on terminal and environment
+static bool common_log_should_use_colors_auto() {
+    // Check NO_COLOR environment variable (https://no-color.org/)
+    if (const char * no_color = std::getenv("NO_COLOR")) {
+        if (no_color[0] != '\0') {
+            return false;
+        }
+    }
+
+    // Check TERM environment variable
+    if (const char * term = std::getenv("TERM")) {
+        if (std::strcmp(term, "dumb") == 0) {
+            return false;
+        }
+    }
+
+    // Check if stdout and stderr are connected to a terminal
+    // We check both because log messages can go to either
+    bool stdout_is_tty = isatty(fileno(stdout));
+    bool stderr_is_tty = isatty(fileno(stderr));
+
+    return stdout_is_tty || stderr_is_tty;
+}
+
+static int64_t t_us() {
+    return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
+}
+
+// colors
+enum common_log_col : int {
+    COMMON_LOG_COL_DEFAULT = 0,
+    COMMON_LOG_COL_BOLD,
+    COMMON_LOG_COL_RED,
+    COMMON_LOG_COL_GREEN,
+    COMMON_LOG_COL_YELLOW,
+    COMMON_LOG_COL_BLUE,
+    COMMON_LOG_COL_MAGENTA,
+    COMMON_LOG_COL_CYAN,
+    COMMON_LOG_COL_WHITE,
+};
+
+// disable colors by default
+static std::vector<const char *> g_col = {
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+    "",
+};
+
+struct common_log_entry {
+    enum ggml_log_level level;
+
+    bool prefix;
+
+    int64_t timestamp;
+
+    std::vector<char> msg;
+
+    // signals the worker thread to stop
+    bool is_end;
+
+    void print(FILE * file = nullptr) const {
+        FILE * fcur = file;
+        if (!fcur) {
+            // stderr displays DBG messages only when their verbosity level is not higher than the threshold
+            // these messages will still be logged to a file
+            if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
+                return;
+            }
+
+            fcur = stdout;
+
+            if (level != GGML_LOG_LEVEL_NONE) {
+                fcur = stderr;
+            }
+        }
+
+        if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
+            if (timestamp) {
+                // [M.s.ms.us]
+                fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
+                    g_col[COMMON_LOG_COL_BLUE],
+                    (int) (timestamp / 1000000 / 60),
+                    (int) (timestamp / 1000000 % 60),
+                    (int) (timestamp / 1000 % 1000),
+                    (int) (timestamp % 1000),
+                    g_col[COMMON_LOG_COL_DEFAULT]);
+            }
+
+            switch (level) {
+                case GGML_LOG_LEVEL_INFO:  fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN],   g_col[COMMON_LOG_COL_DEFAULT]); break;
+                case GGML_LOG_LEVEL_WARN:  fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], ""                            ); break;
+                case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED],     ""                            ); break;
+                case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW],  ""                            ); break;
+                default:
+                    break;
+            }
+        }
+
+        fprintf(fcur, "%s", msg.data());
+
+        if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
+            fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
+        }
+
+        fflush(fcur);
+    }
+};
+
+struct common_log {
+    // default capacity - will be expanded if needed
+    common_log() : common_log(256) {}
+
+    common_log(size_t capacity) {
+        file = nullptr;
+        prefix = false;
+        timestamps = false;
+        running = false;
+        t_start = t_us();
+
+        // initial message size - will be expanded if longer messages arrive
+        entries.resize(capacity);
+        for (auto & entry : entries) {
+            entry.msg.resize(256);
+        }
+
+        head = 0;
+        tail = 0;
+
+        resume();
+    }
+
+    ~common_log() {
+        pause();
+        if (file) {
+            fclose(file);
+        }
+    }
+
+private:
+    std::mutex mtx;
+    std::thread thrd;
+    std::condition_variable cv;
+
+    FILE * file;
+
+    bool prefix;
+    bool timestamps;
+    bool running;
+
+    int64_t t_start;
+
+    // ring buffer of entries
+    std::vector<common_log_entry> entries;
+    size_t head;
+    size_t tail;
+
+    // worker thread copies into this
+    common_log_entry cur;
+
+public:
+    void add(enum ggml_log_level level, const char * fmt, va_list args) {
+        std::lock_guard<std::mutex> lock(mtx);
+
+        if (!running) {
+            // discard messages while the worker thread is paused
+            return;
+        }
+
+        auto & entry = entries[tail];
+
+        {
+            // cannot use args twice, so make a copy in case we need to expand the buffer
+            va_list args_copy;
+            va_copy(args_copy, args);
+
+#if 1
+            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
+            if (n >= entry.msg.size()) {
+                entry.msg.resize(n + 1);
+                vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
+            }
+#else
+            // hack for bolding arguments
+
+            std::stringstream ss;
+            for (int i = 0; fmt[i] != 0; i++) {
+                if (fmt[i] == '%') {
+                    ss << LOG_COL_BOLD;
+                    while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
+                    ss << LOG_COL_DEFAULT;
+                    if (fmt[i] == 0) break;
+                }
+                ss << fmt[i];
+            }
+            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
+            if (n >= entry.msg.size()) {
+                entry.msg.resize(n + 1);
+                vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
+            }
+#endif
+            va_end(args_copy);
+        }
+
+        entry.level = level;
+        entry.prefix = prefix;
+        entry.timestamp = 0;
+        if (timestamps) {
+            entry.timestamp = t_us() - t_start;
+        }
+        entry.is_end = false;
+
+        tail = (tail + 1) % entries.size();
+        if (tail == head) {
+            // expand the buffer
+            std::vector<common_log_entry> new_entries(2*entries.size());
+
+            size_t new_tail = 0;
+
+            do {
+                new_entries[new_tail] = std::move(entries[head]);
+
+                head     = (head + 1) % entries.size();
+                new_tail = (new_tail + 1);
+            } while (head != tail);
+
+            head = 0;
+            tail = new_tail;
+
+            for (size_t i = tail; i < new_entries.size(); i++) {
+                new_entries[i].msg.resize(256);
+            }
+
+            entries = std::move(new_entries);
+        }
+
+        cv.notify_one();
+    }
+
+    void resume() {
+        std::lock_guard<std::mutex> lock(mtx);
+
+        if (running) {
+            return;
+        }
+
+        running = true;
+
+        thrd = std::thread([this]() {
+            while (true) {
+                {
+                    std::unique_lock<std::mutex> lock(mtx);
+                    cv.wait(lock, [this]() { return head != tail; });
+
+                    cur = entries[head];
+
+                    head = (head + 1) % entries.size();
+                }
+
+                if (cur.is_end) {
+                    break;
+                }
+
+                cur.print(); // stdout and stderr
+
+                if (file) {
+                    cur.print(file);
+                }
+            }
+        });
+    }
+
+    void pause() {
+        {
+            std::lock_guard<std::mutex> lock(mtx);
+
+            if (!running) {
+                return;
+            }
+
+            running = false;
+
+            // push an entry to signal the worker thread to stop
+            {
+                auto & entry = entries[tail];
+                entry.is_end = true;
+
+                tail = (tail + 1) % entries.size();
+            }
+
+            cv.notify_one();
+        }
+
+        thrd.join();
+    }
+
+    void set_file(const char * path) {
+        pause();
+
+        if (file) {
+            fclose(file);
+        }
+
+        if (path) {
+            file = fopen(path, "w");
+        } else {
+            file = nullptr;
+        }
+
+        resume();
+    }
+
+    void set_colors(bool colors) {
+        pause();
+
+        if (colors) {
+            g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
+            g_col[COMMON_LOG_COL_BOLD]    = LOG_COL_BOLD;
+            g_col[COMMON_LOG_COL_RED]     = LOG_COL_RED;
+            g_col[COMMON_LOG_COL_GREEN]   = LOG_COL_GREEN;
+            g_col[COMMON_LOG_COL_YELLOW]  = LOG_COL_YELLOW;
+            g_col[COMMON_LOG_COL_BLUE]    = LOG_COL_BLUE;
+            g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
+            g_col[COMMON_LOG_COL_CYAN]    = LOG_COL_CYAN;
+            g_col[COMMON_LOG_COL_WHITE]   = LOG_COL_WHITE;
+        } else {
+            for (size_t i = 0; i < g_col.size(); i++) {
+                g_col[i] = "";
+            }
+        }
+
+        resume();
+    }
+
+    void set_prefix(bool prefix) {
+        std::lock_guard<std::mutex> lock(mtx);
+
+        this->prefix = prefix;
+    }
+
+    void set_timestamps(bool timestamps) {
+        std::lock_guard<std::mutex> lock(mtx);
+
+        this->timestamps = timestamps;
+    }
+};
+
+//
+// public API
+//
+
+struct common_log * common_log_init() {
+    return new common_log;
+}
+
+struct common_log * common_log_main() {
+    static struct common_log log;
+    static std::once_flag init_flag;
+    std::call_once(init_flag, [&]() {
+        // Set default to auto-detect colors
+        log.set_colors(common_log_should_use_colors_auto());
+    });
+
+    return &log;
+}
+
+void common_log_pause(struct common_log * log) {
+    log->pause();
+}
+
+void common_log_resume(struct common_log * log) {
+    log->resume();
+}
+
+void common_log_free(struct common_log * log) {
+    delete log;
+}
+
+void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...)
{ + va_list args; + va_start(args, fmt); + log->add(level, fmt, args); + va_end(args); +} + +void common_log_set_file(struct common_log * log, const char * file) { + log->set_file(file); +} + +void common_log_set_colors(struct common_log * log, log_colors colors) { + if (colors == LOG_COLORS_AUTO) { + log->set_colors(common_log_should_use_colors_auto()); + return; + } + + if (colors == LOG_COLORS_DISABLED) { + log->set_colors(false); + return; + } + + GGML_ASSERT(colors == LOG_COLORS_ENABLED); + log->set_colors(true); +} + +void common_log_set_prefix(struct common_log * log, bool prefix) { + log->set_prefix(prefix); +} + +void common_log_set_timestamps(struct common_log * log, bool timestamps) { + log->set_timestamps(timestamps); +} + +static int common_get_verbosity(enum ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG; + case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO; + case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN; + case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR; + case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO + case GGML_LOG_LEVEL_NONE: + default: + return LOG_LEVEL_OUTPUT; + } +} + +void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) { + auto verbosity = common_get_verbosity(level); + if (verbosity <= common_log_verbosity_thold) { + common_log_add(common_log_main(), level, "%s", text); + } +} diff --git a/common/log.h b/common/log.h index 2cbf4a74..596352e6 100644 --- a/common/log.h +++ b/common/log.h @@ -1,4 +1,5 @@ #pragma once +#include "ggml.h" // for ggml_log_level #include #include #include @@ -8,6 +9,124 @@ #include #include + + + +#define LOG_CLR_TO_EOL "\033[K\r" +#define LOG_COL_DEFAULT "\033[0m" +#define LOG_COL_BOLD "\033[1m" +#define LOG_COL_RED "\033[31m" +#define LOG_COL_GREEN "\033[32m" +#define LOG_COL_YELLOW "\033[33m" +#define LOG_COL_BLUE "\033[34m" +#define LOG_COL_MAGENTA "\033[35m" +#define LOG_COL_CYAN "\033[36m" +#define LOG_COL_WHITE "\033[37m" + +#ifndef __GNUC__ +# define LOG_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) && !defined(__clang__) +# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +# define LOG_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) +#endif + +#define LOG_LEVEL_DEBUG 4 +#define LOG_LEVEL_INFO 3 +#define LOG_LEVEL_WARN 2 +#define LOG_LEVEL_ERROR 1 +#define LOG_LEVEL_OUTPUT 0 // output data from tools + +#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG +#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO + +enum log_colors { + LOG_COLORS_AUTO = -1, + LOG_COLORS_DISABLED = 0, + LOG_COLORS_ENABLED = 1, +}; + +// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower +// set via common_log_set_verbosity() +extern int common_log_verbosity_thold; + +void common_log_set_verbosity_thold(int verbosity); // not thread-safe + +void common_log_default_callback(enum ggml_log_level level, const char* text, void* user_data); + +// the common_log uses an internal worker thread to print/write log messages +// when the worker thread is paused, incoming log messages are discarded +struct common_log; + +struct common_log* common_log_init(); +struct common_log* common_log_main(); // singleton, automatically destroys itself on exit +void common_log_pause(struct common_log* log); // pause the worker thread, not thread-safe +void common_log_resume(struct common_log* log); // resume the worker thread, not thread-safe +void common_log_free(struct common_log* log); + +LOG_ATTRIBUTE_FORMAT(3, 4) +void common_log_add(struct common_log* log, enum ggml_log_level level, const char* fmt, ...); + +// defaults: file = NULL, colors = false, prefix = false, timestamps = false +// +// regular log output: +// +// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34) +// llm_load_tensors: ggml ctx size = 0.27 MiB +// llm_load_tensors: offloading 32 repeating layers to GPU +// llm_load_tensors: offloading non-repeating layers to GPU +// +// with prefix = true, timestamps = true, the log output will look like this: +// +// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34) +// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB +// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU +// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU +// +// D - debug (stderr, V = LOG_DEFAULT_DEBUG) +// I - info (stdout, V = LOG_DEFAULT_INFO) +// W - warning (stderr, V = LOG_DEFAULT_WARN) +// E - error (stderr, V = LOG_DEFAULT_ERROR) +// O - output (stdout, V = LOG_DEFAULT_OUTPUT) +// + +void common_log_set_file(struct common_log* log, const char* file); // not thread-safe +void common_log_set_colors(struct common_log* log, log_colors colors); // not thread-safe +void common_log_set_prefix(struct common_log* log, bool prefix); // whether to output prefix to each log +void common_log_set_timestamps(struct common_log* log, bool timestamps); // whether to output timestamps in the prefix + +// helper macros for logging +// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold +// +// for example: +// +// LOG_DBG("this is a debug message: %d\n", expensive_function()); +// +// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold +// + +#define LOG_TMPL(level, verbosity, ...) \ + do { \ + if ((verbosity) <= common_log_verbosity_thold) { \ + common_log_add(common_log_main(), (level), __VA_ARGS__); \ + } \ + } while (0) + +#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__) +#define LOGV(verbosity, ...) 
LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) + +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO + +#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__) +#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__) +#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__) +#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__) +#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__) + // -------------------------------- // // Basic usage: @@ -293,11 +412,11 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std:: // Main LOG macro. // behaves like printf, and supports arguments the exact same way. // -#if !defined(_MSC_VER) || defined(__clang__) - #define LOG(...) LOG_IMPL(__VA_ARGS__, "") -#else - #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "") -#endif +//#if !defined(_MSC_VER) || defined(__clang__) +// #define LOG(...) LOG_IMPL(__VA_ARGS__, "") +//#else +// #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "") +//#endif // Main TEE macro. // does the same as LOG @@ -721,3 +840,4 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch) #define LOG_DUMP_CMDLINE(...) // dummy stub #endif // LOG_DISABLE_LOGS + diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index fc134231..e67ca7f3 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -38,10 +38,29 @@ namespace fs = std::filesystem; // NOTE: this is copied from common.cpp to avoid linking with libcommon // returns true if successful, false otherwise + +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string& str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif + static bool fs_create_directory_with_parents(const std::string& path) { #ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); + std::wstring wpath = utf8_to_wstring(path); // if the path already exists, check whether it's a directory const DWORD attributes = GetFileAttributesW(wpath.c_str()); diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index d0c316d6..28b78aec 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -12,8 +12,15 @@ endif() set(TARGET_SRCS server.cpp - utils.hpp httplib.h + server-task.cpp + server-task.h + server-queue.cpp + server-queue.h + server-common.cpp + server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/examples/server/utils.hpp b/examples/server/server-common.cpp similarity index 76% rename from examples/server/utils.hpp rename to examples/server/server-common.cpp index fdd41e1b..6be7ff5f 100644 --- a/examples/server/utils.hpp +++ b/examples/server/server-common.cpp 
@@ -1,119 +1,29 @@ -#pragma once - -#include "llama.h" -#include -#include "common.h" - -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include -#include "base64.hpp" -#include "mtmd.h" -#include "mtmd-helper.h" -#include "chat.h" -#include -#include -#include -#include -#include - -// increase max payload length to allow use of larger context size -#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 -// increase backlog size to avoid connection resets for >> 1 slots -#define CPPHTTPLIB_LISTEN_BACKLOG 512 -// increase max URI length to handle longer prompts in query string -#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768 -// disable Nagle's algorithm -#define CPPHTTPLIB_TCP_NODELAY true -#include "httplib.h" - -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" - -using json = nlohmann::ordered_json; - -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; - -extern bool server_verbose; -extern bool server_log_json; - -#ifndef SERVER_VERBOSE -#define SERVER_VERBOSE 1 -#endif - -#if SERVER_VERBOSE != 1 -#define LOG_VERBOSE(MSG, ...) -#else -#define LOG_VERBOSE(MSG, ...) \ - do \ - { \ - if (server_verbose) \ - { \ - server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) +#include "server-common.h" using raw_buffer = std::vector; -static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra); -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const& err) { - std::stringstream ss; - ss << "Wrong type supplied for parameter '" << key << "'. 
Expected '" << json(default_value).type_name() << "', using default value: "<< err.what(); - LOG_WARNING(ss.str().c_str(), body); - return default_value; - } - } else { - return default_value; +server_grammar_trigger::server_grammar_trigger(const json& in) { + value.type = (common_grammar_trigger_type)in.at("type").get(); + value.value = in.at("value").get(); + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + value.token = (llama_token)in.at("token").get(); } } -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; - - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger& value) : value(value) {} - server_grammar_trigger(const json& in) { - value.type = (common_grammar_trigger_type)in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token)in.at("token").get(); - } +json server_grammar_trigger::to_json() const { + json out{ + {"type", (int)value.type}, + {"value", value.value}, + }; + if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { + out["token"] = (int)value.token; } + return out; +} - json to_json() const { - json out{ - {"type", (int)value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int)value.token; - } - return out; - } -}; -static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) { +void server_log(const char* level, const char* function, int line, const char* message, const json& extra) { std::stringstream ss_tid; ss_tid << std::this_thread::get_id(); json log = json{ @@ -127,14 +37,15 @@ static inline void server_log(const char * level, const char * function, int lin {"function", function}, {"line", line}, {"msg", message}, - }); + }); if (!extra.empty()) { log.merge_patch(extra); } printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str()); - } else { + } + else { char buf[1024]; snprintf(buf, 1024, "%4s [%24s] %s", level, function, message); @@ -143,7 +54,7 @@ static inline void server_log(const char * level, const char * function, int lin } std::stringstream ss; ss << buf << " |"; - for (const auto & el : log.items()) + for (const auto& el : log.items()) { const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); ss << " " << el.key() << "=" << value; @@ -159,20 +70,12 @@ static inline void server_log(const char * level, const char * function, int lin // chat template utils // -// -// base64 utils (TODO: move to common in the future) -// -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -static inline bool is_base64(uint8_t c) { +bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string & encoded_string) { +std::vector base64_decode(const std::string& encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -191,9 +94,9 @@ static inline std::vector base64_decode(const std::string & encoded_str char_array_4[i] = base64_chars.find(char_array_4[i]); } - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = 
((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (i = 0; (i < 3); i++) { ret.push_back(char_array_3[i]); @@ -212,9 +115,9 @@ static inline std::vector base64_decode(const std::string & encoded_str char_array_4[j] = base64_chars.find(char_array_4[j]); } - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (j = 0; j < i - 1; j++) { ret.push_back(char_array_3[j]); @@ -228,7 +131,7 @@ static inline std::vector base64_decode(const std::string & encoded_str // random string / id // -static std::string random_string() { +std::string random_string() { static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); std::random_device rd; @@ -243,33 +146,33 @@ static std::string random_string() { return result; } -static std::string gen_chatcmplid() { +std::string gen_chatcmplid() { std::stringstream chatcmplid; chatcmplid << "chatcmpl-" << random_string(); return chatcmplid.str(); } -static std::string gen_tool_call_id() { +std::string gen_tool_call_id() { return random_string(); } // // other common utils // -static float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length) { +float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length) { float sim = float(lcp) * 2 / (prompt_length + cache_length); return sim; } -static size_t common_part(const std::vector & a, const std::vector & b) { +size_t common_part(const std::vector& a, const std::vector& b) { size_t i; for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} return i; } -static size_t common_part(const std::string & a, const std::string & b) { +size_t common_part(const std::string& a, const std::string& b) { size_t i; for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} @@ -279,7 +182,7 @@ static size_t common_part(const std::string & a, const std::string & b) { // return the last index of character that can form a valid string // if the last character is potentially cut in half, return the index before the cut // if validate_utf8(text) == text.size(), then the whole text is valid utf8 -static size_t validate_utf8(const std::string& text) { +size_t validate_utf8(const std::string& text) { size_t len = text.size(); if (len == 0) return 0; @@ -291,12 +194,12 @@ static size_t validate_utf8(const std::string& text) { // 2-byte character start: 110xxxxx // Needs at least 2 bytes if (i < 2) return len - i; - } + } else if ((c & 0xF0) == 0xE0) { // 3-byte character start: 1110xxxx // Needs at least 3 bytes if (i < 3) return len - i; - } + } else if ((c & 0xF8) == 0xF0) { // 4-byte character start: 11110xxx // Needs at least 4 bytes @@ -310,7 +213,7 @@ static size_t validate_utf8(const std::string& text) { // TODO: reuse llama_detokenize template -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { +static std::string tokens_to_str(llama_context* ctx, Iter begin, Iter end) { std::string ret; for (; begin != end; ++begin) { ret += llama_token_to_piece(ctx, *begin); @@ -319,8 +222,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { return ret; } +std::string tokens_to_str(llama_context* ctx, 
const llama_tokens& tokens) { + return tokens_to_str(ctx, tokens.begin(), tokens.end()); +} + // format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token) { std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character @@ -335,10 +242,6 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } -struct common_prefix { - size_t first = 0; - size_t second = 0; -}; common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b) { common_prefix prefix; @@ -347,7 +250,7 @@ common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b) return prefix; } -common_prefix find_common_string_prefix(const std::string & a_str, const std::string & b_str, const std::set& ignore_set) { +common_prefix find_common_string_prefix(const std::string& a_str, const std::string& b_str, const std::set& ignore_set) { size_t i = 0; size_t j = 0; while (i < a_str.size() && j < b_str.size()) { @@ -378,7 +281,7 @@ common_prefix find_common_string_prefix(const std::string & a_str, const std::st } size_t find_n_tokens_from_string(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, - std::vector & map) { + std::vector& map) { size_t n = 0; size_t string_len = 0; std::string str; @@ -429,9 +332,9 @@ common_prefix find_largest_common_number(const std::vector& a_list, cons return token_prefix; } -size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set & ignore_set, +size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set& ignore_set, std::vector& map) { - bool use_ignore = ignore_set.size()>0; + bool use_ignore = ignore_set.size() > 0; size_t n = 0; size_t string_len = 0; size_t string_len_ignore = 0; @@ -458,13 +361,13 @@ size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const lla return map.size(); } -common_prefix find_common_text_token_prefix(const llama_context * ctx, const llama_tokens & a, const llama_tokens& b, +common_prefix find_common_text_token_prefix(const llama_context* ctx, const llama_tokens& a, const llama_tokens& b, size_t start, bool exact) { common_prefix token_prefix; - if (a.size()<= start || b.size()<= start) { + if (a.size() <= start || b.size() <= start) { return token_prefix; } - std::set ignore_set = { ' ', '\n' ,'\r'}; + std::set ignore_set = { ' ', '\n' ,'\r' }; llama_tokens a_sub(a.begin() + start, a.end()); llama_tokens b_sub(b.begin() + start, b.end()); @@ -494,99 +397,88 @@ common_prefix find_common_text_token_prefix(const llama_context * ctx, const lla } -struct completion_token_output { - llama_token tok; - std::string text_to_send; - float prob; - struct prob_info { - llama_token tok; - std::string txt; - float prob; - }; - std::vector probs; - - json to_json(bool post_sampling_probs) const { - json probs_for_token = json::array(); - for (const auto& p : probs) { - std::string txt(p.txt); - txt.resize(validate_utf8(txt)); - probs_for_token.push_back(json{ - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.txt)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? 
p.prob : logarithm(p.prob) - }, - }); - } - return probs_for_token; + } + return probs_for_token; +} - static float logarithm(float x) { - // nlohmann::json converts -inf to null, so we need to prevent that - return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x); + float completion_token_output::logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x); +} + +std::vector<unsigned char> completion_token_output::str_to_bytes(const std::string& str) { + std::vector<unsigned char> bytes; + for (unsigned char c : str) { + bytes.push_back(c); } + return bytes; +} - static std::vector<unsigned char> str_to_bytes(const std::string& str) { - std::vector<unsigned char> bytes; - for (unsigned char c : str) { - bytes.push_back(c); - } - return bytes; + +json completion_token_output::probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto& p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json{ + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ?
"top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, - }); - } - return out; - } -}; - // convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context * ctx, const std::vector & probs) { +json probs_vector_to_json(const llama_context* ctx, const std::vector& probs) { json out = json::array(); - for (const auto & prob : probs) { + for (const auto& prob : probs) { json probs_for_token = json::array(); - for (const auto & p : prob.probs) { + for (const auto& p : prob.probs) { const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json { + probs_for_token.push_back(json{ {"tok_str", tok_str}, {"prob", p.prob}, - }); + }); } const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json { + out.push_back(json{ {"content", tok_str}, {"probs", probs_for_token}, - }); + }); } return out; } -static bool server_sent_event(httplib::DataSink& sink, const json& data) { +bool server_sent_event(httplib::DataSink& sink, const json& data) { const std::string str = "data: " + data.dump(-1, ' ', false, json::error_handler_t::replace) + @@ -597,11 +489,11 @@ static bool server_sent_event(httplib::DataSink& sink, const json& data) { return sink.write(str.c_str(), str.size()); } -static bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) { +bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) { const std::string str = - (data.contains("event") && data.contains("data"))? + (data.contains("event") && data.contains("data")) ? ("event: " + data.at("event").get() + "\n" + - "data: " + data.at("data").dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n"): + "data: " + data.at("data").dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n") : ("data: " + data.at("data").dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n"); LOG_VERBOSE("data stream, to_send: %s", str.c_str()); @@ -613,7 +505,7 @@ static bool server_sent_anthropic_event(httplib::DataSink& sink, const json& dat // OAI utils // // used by /completions endpoint -static json oaicompat_chat_params_parse(const json& body) { +json oaicompat_chat_params_parse(const json& body) { json llama_params; if (!body.contains("prompt")) { @@ -664,19 +556,9 @@ static json oaicompat_chat_params_parse(const json& body) { return llama_params; } -struct oaicompat_parser_options { - bool use_jinja; - bool prefill_assistant; - common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates* tmpls; - bool allow_image; - bool allow_audio; - bool enable_thinking = true; -}; // used by /chat/completions endpoint -static json oaicompat_chat_params_parse( +json oaicompat_chat_params_parse( const struct llama_model* model, json& body, /* openai api json semantics */ const oaicompat_parser_options& opt, @@ -878,7 +760,7 @@ static json oaicompat_chat_params_parse( /*"whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n" "when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n"*/ - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" &&opt.prefill_assistant; + bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; common_chat_msg last_message; if (prefill_assistant_message) { 
last_message = inputs.messages.back(); @@ -899,14 +781,15 @@ static json oaicompat_chat_params_parse( // Apply chat template to the list of messages auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); - + /* Append assistant prefilled message */ if (prefill_assistant_message) { if (!last_message.content_parts.empty()) { - for (auto & p : last_message.content_parts) { + for (auto& p : last_message.content_parts) { chat_params.prompt += p.text; } - } else { + } + else { chat_params.prompt += last_message.content; } } @@ -918,7 +801,7 @@ static json oaicompat_chat_params_parse( } llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { + for (const auto& trigger : chat_params.grammar_triggers) { server_grammar_trigger ct(trigger); grammar_triggers.push_back(ct.to_json()); } @@ -942,7 +825,8 @@ static json oaicompat_chat_params_parse( throw std::runtime_error("logprobs is not supported with tools + stream"); } llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { + } + else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { throw std::runtime_error("top_logprobs requires logprobs to be set to true"); } @@ -960,25 +844,27 @@ static json oaicompat_chat_params_parse( return llama_params; } -static json anthropic_params_from_json( +json anthropic_params_from_json( const struct llama_model* model, - const json & body_in, /* anthropic messages api json semantics */ - const oaicompat_parser_options & opt, - std::vector & out_files) + const json& body_in, /* anthropic messages api json semantics */ + const oaicompat_parser_options& opt, + std::vector& out_files) { json body = body_in; json llama_params; if (body.contains("stop_sequences")) { llama_params["stop"] = body.at("stop_sequences"); - } else { + } + else { llama_params["stop"] = json::array(); } // handle max_tokens (required in Anthropic, but we're permissive) if (!body.contains("max_tokens")) { llama_params["n_predict"] = 4096; - } else { + } + else { llama_params["n_predict"] = body.at("max_tokens"); } @@ -1010,8 +896,9 @@ static json anthropic_params_from_json( if (system_param.is_string()) { system_content = system_param.get(); - } else if (system_param.is_array()) { - for (const auto & block : system_param) { + } + else if (system_param.is_array()) { + for (const auto& block : system_param) { if (json_value(block, "type", std::string()) == "text") { system_content += json_value(block, "text", std::string()); } @@ -1021,18 +908,18 @@ static json anthropic_params_from_json( oai_messages.push_back({ {"role", "system"}, {"content", system_content} - }); + }); } if (!body.contains("messages")) { throw std::runtime_error("'messages' is required"); } - json & messages = body.at("messages"); + json& messages = body.at("messages"); if (!messages.is_array()) { throw std::runtime_error("Expected 'messages' to be an array"); } - for (auto & msg : messages) { + for (auto& msg : messages) { std::string role = json_value(msg, "role", std::string()); if (role != "assistant" && !msg.contains("content")) { throw std::runtime_error("All non-assistant messages must contain 'content'"); @@ -1043,7 +930,7 @@ static json anthropic_params_from_json( } } - json & content = msg.at("content"); + json& content = msg.at("content"); if (content.is_string()) { oai_messages.push_back(msg); @@ -1059,12 +946,13 @@ static json 
anthropic_params_from_json( json tool_results = json::array(); bool has_tool_calls = false; - for (auto & block : content) { + for (auto& block : content) { std::string type = json_value(block, "type", std::string()); if (type == "text") { converted_content.push_back(block); - } else if (type == "image") { + } + else if (type == "image") { json source = json_value(block, "source", json::object()); std::string source_type = json_value(source, "type", std::string()); @@ -1077,17 +965,19 @@ static json anthropic_params_from_json( {"image_url", { {"url", "data:" + media_type + ";base64," + data} }} - }); - } else if (source_type == "url") { + }); + } + else if (source_type == "url") { std::string url = json_value(source, "url", std::string()); converted_content.push_back({ {"type", "image_url"}, {"image_url", { {"url", url} }} - }); + }); } - } else if (type == "tool_use") { + } + else if (type == "tool_use") { tool_calls.push_back({ {"id", json_value(block, "id", std::string())}, {"type", "function"}, @@ -1095,17 +985,19 @@ static json anthropic_params_from_json( {"name", json_value(block, "name", std::string())}, {"arguments", json_value(block, "input", json::object()).dump()} }} - }); + }); has_tool_calls = true; - } else if (type == "tool_result") { + } + else if (type == "tool_result") { std::string tool_use_id = json_value(block, "tool_use_id", std::string()); auto result_content = json_value(block, "content", json()); std::string result_text; if (result_content.is_string()) { result_text = result_content.get(); - } else if (result_content.is_array()) { - for (const auto & c : result_content) { + } + else if (result_content.is_array()) { + for (const auto& c : result_content) { if (json_value(c, "type", std::string()) == "text") { result_text += json_value(c, "text", std::string()); } @@ -1116,16 +1008,17 @@ static json anthropic_params_from_json( {"role", "tool"}, {"tool_call_id", tool_use_id}, {"content", result_text} - }); + }); } } if (!tool_results.empty()) { if (!converted_content.empty() || has_tool_calls) { - json new_msg = {{"role", role}}; + json new_msg = { {"role", role} }; if (!converted_content.empty()) { new_msg["content"] = converted_content; - } else if (has_tool_calls) { + } + else if (has_tool_calls) { new_msg["content"] = ""; } if (!tool_calls.empty()) { @@ -1133,15 +1026,17 @@ static json anthropic_params_from_json( } oai_messages.push_back(new_msg); } - for (const auto & tool_msg : tool_results) { + for (const auto& tool_msg : tool_results) { oai_messages.push_back(tool_msg); } - } else { + } + else { if (!converted_content.empty() || has_tool_calls) { - json new_msg = {{"role", role}}; + json new_msg = { {"role", role} }; if (!converted_content.empty()) { new_msg["content"] = converted_content; - } else if (has_tool_calls) { + } + else if (has_tool_calls) { new_msg["content"] = ""; } if (!tool_calls.empty()) { @@ -1154,9 +1049,9 @@ static json anthropic_params_from_json( json oai_tools = json::array(); if (body.contains("tools")) { - json & tools = body.at("tools"); + json& tools = body.at("tools"); if (tools.is_array()) { - for (auto & tool : tools) { + for (auto& tool : tools) { oai_tools.push_back({ {"type", "function"}, {"function", { @@ -1164,31 +1059,33 @@ static json anthropic_params_from_json( {"description", json_value(tool, "description", std::string())}, {"parameters", tool.contains("input_schema") ? 
tool.at("input_schema") : json::object()} }} - }); + }); } } } std::string oai_tool_choice = "auto"; if (body.contains("tool_choice")) { - json & tc = body.at("tool_choice"); + json& tc = body.at("tool_choice"); if (tc.is_object()) { std::string type = json_value(tc, "type", std::string()); if (type == "auto") { oai_tool_choice = "auto"; - } else if (type == "any") { + } + else if (type == "any") { oai_tool_choice = "required"; - } else if (type == "tool") { + } + else if (type == "tool") { oai_tool_choice = "required"; } } } - for (auto & msg : oai_messages) { + for (auto& msg : oai_messages) { if (!msg.contains("content")) { continue; } - json & content = msg.at("content"); + json& content = msg.at("content"); if (content.is_string() || content.is_null()) { continue; } @@ -1196,21 +1093,21 @@ static json anthropic_params_from_json( continue; } - for (auto & p : content) { + for (auto& p : content) { std::string type = json_value(p, "type", std::string()); if (type == "image_url") { if (!opt.allow_image) { throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json image_url = json_value(p, "image_url", json::object()); + json image_url = json_value(p, "image_url", json::object()); std::string url = json_value(image_url, "url", std::string()); if (string_starts_with(url, "http")) { // download remote image common_remote_params params; params.headers.push_back("User-Agent: ik_llama.cpp/"); params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds + params.timeout = 10; // seconds LOG_INFO("downloading image from '%s'\n", url.c_str()); auto res = common_remote_get_content(url, params); if (200 <= res.first && res.first < 300) { @@ -1218,19 +1115,24 @@ static json anthropic_params_from_json( raw_buffer data; data.insert(data.end(), res.second.begin(), res.second.end()); out_files.push_back(data); - } else { + } + else { throw std::runtime_error("Failed to download image"); } - } else { + } + else { // try to decode base64 image std::vector parts = string_split(url, /*separator*/ ','); if (parts.size() != 2) { throw std::runtime_error("Invalid image_url.url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { + } + else if (!string_starts_with(parts[0], "data:image/")) { throw std::runtime_error("Invalid image_url.url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { + } + else if (!string_ends_with(parts[0], "base64")) { throw std::runtime_error("image_url.url must be base64 encoded"); - } else { + } + else { auto base64_data = parts[1]; auto decoded_data = base64_decode(base64_data); out_files.push_back(decoded_data); @@ -1241,13 +1143,14 @@ static json anthropic_params_from_json( p["type"] = "text"; p["text"] = mtmd_default_marker(); p.erase("image_url"); - } else if (type == "input_audio") { + } + else if (type == "input_audio") { if (!opt.allow_audio) { throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json input_audio = json_value(p, "input_audio", json::object()); - std::string data = json_value(input_audio, "data", std::string()); + json input_audio = json_value(p, "input_audio", json::object()); + std::string data = json_value(input_audio, "data", std::string()); std::string format = json_value(input_audio, "format", std::string()); if (format != "wav" && format != "mp3") { throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); @@ -1264,16 +1167,16 
@@ static json anthropic_params_from_json( } common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(oai_messages); - inputs.tools = common_chat_tools_parse_oaicompat(oai_tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(oai_tool_choice); - inputs.json_schema = ""; - inputs.grammar = ""; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.messages = common_chat_msgs_parse_oaicompat(oai_messages); + inputs.tools = common_chat_tools_parse_oaicompat(oai_tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(oai_tool_choice); + inputs.json_schema = ""; + inputs.grammar = ""; + inputs.use_jinja = opt.use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - inputs.reasoning_format = opt.reasoning_format; - inputs.enable_thinking = opt.enable_thinking; + inputs.reasoning_format = opt.reasoning_format; + inputs.enable_thinking = opt.enable_thinking; if (opt.enable_thinking && opt.prefill_assistant) { if (!inputs.messages.empty() && inputs.messages.back().role == "assistant") { @@ -1288,7 +1191,7 @@ static json anthropic_params_from_json( // merge the template args provided from command line with the args provided in the user request auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object()); inputs.chat_template_kwargs = opt.chat_template_kwargs; - for (const auto & item : chat_template_kwargs_object.items()) { + for (const auto& item : chat_template_kwargs_object.items()) { inputs.chat_template_kwargs[item.key()] = item.value().dump(); } @@ -1296,9 +1199,11 @@ static json anthropic_params_from_json( auto enable_thinking_kwarg = json_value(inputs.chat_template_kwargs, "enable_thinking", std::string("")); if (enable_thinking_kwarg == "true") { inputs.enable_thinking = true; - } else if (enable_thinking_kwarg == "false") { + } + else if (enable_thinking_kwarg == "false") { inputs.enable_thinking = false; - } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { + } + else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)"); } @@ -1310,7 +1215,7 @@ static json anthropic_params_from_json( inputs.messages.pop_back(); // sanity check, max one assistant message at the end of the list - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ + if (!inputs.messages.empty() && inputs.messages.back().role == "assistant") { throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); } @@ -1329,29 +1234,30 @@ static json anthropic_params_from_json( // Append assistant prefilled message if (prefill_assistant_message) { if (!last_message.content_parts.empty()) { - for (auto & p : last_message.content_parts) { + for (auto& p : last_message.content_parts) { chat_params.prompt += p.text; } - } else { + } + else { chat_params.prompt += last_message.content; } } - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; if (!chat_params.grammar.empty()) { llama_params["grammar"] = chat_params.grammar; } - llama_params["grammar_lazy"] = chat_params.grammar_lazy; + 
llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { + for (const auto& trigger : chat_params.grammar_triggers) { server_grammar_trigger ct(trigger); grammar_triggers.push_back(ct.to_json()); } llama_params["grammar_triggers"] = grammar_triggers; llama_params["preserved_tokens"] = chat_params.preserved_tokens; llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; - for (const auto & stop : chat_params.additional_stops) { + for (const auto& stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } @@ -1364,7 +1270,7 @@ static json anthropic_params_from_json( // Copy remaining properties to llama_params // This allows user to use llama.cpp-specific params like "mirostat", ... via Anthropic endpoint. // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp - for (const auto & item : body.items()) { + for (const auto& item : body.items()) { // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" if (!llama_params.contains(item.key()) || item.key() == "n_predict") { llama_params[item.key()] = item.value(); @@ -1379,7 +1285,7 @@ static json anthropic_params_from_json( // tokenizer and input processing utils // -static bool json_is_array_of_numbers(const json& data) { +bool json_is_array_of_numbers(const json& data) { if (data.is_array()) { for (const auto& e : data) { if (!e.is_number_integer()) { @@ -1392,7 +1298,7 @@ static bool json_is_array_of_numbers(const json& data) { } // is array having BOTH numbers & strings? -static bool json_is_array_of_mixed_numbers_strings(const json& data) { +bool json_is_array_of_mixed_numbers_strings(const json& data) { bool seen_string = false; bool seen_number = false; if (data.is_array()) { @@ -1408,7 +1314,7 @@ static bool json_is_array_of_mixed_numbers_strings(const json& data) { } // does array have any individual integers/tokens? -static bool json_is_array_and_contains_numbers(const json& data) { +bool json_is_array_and_contains_numbers(const json& data) { if (data.is_array()) { for (const auto& e : data) { if (e.is_number_integer()) { @@ -1421,7 +1327,7 @@ static bool json_is_array_and_contains_numbers(const json& data) { } // get value by path(key1 / key2) -static json json_get_nested_values(const std::vector& paths, const json& js) { +json json_get_nested_values(const std::vector& paths, const json& js) { json result = json::object(); for (const std::string& path : paths) { @@ -1449,7 +1355,7 @@ static json json_get_nested_values(const std::vector& paths, const * - only string, example: "string" * - mixed string and tokens, example: [12, 34, "string", 56, 78] */ -static std::vector tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special) { +std::vector tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special) { // If `add_bos` is true, we only add BOS, when json_prompt is a string, // or the first element of the json_prompt array is a string. 
std::vector prompt_tokens; @@ -1488,72 +1394,68 @@ static std::vector tokenize_mixed(const llama_vocab* vocab, const j return prompt_tokens; } -static json format_tokenizer_response(const std::vector & tokens) { - return json { +json format_tokenizer_response(const std::vector& tokens) { + return json{ {"tokens", tokens} }; } -static json format_detokenized_response(const std::string & content) { - return json { +json format_detokenized_response(const std::string& content) { + return json{ {"content", content} }; } -static json format_error_response(const std::string & message, const enum error_type type) { +json format_error_response(const std::string& message, const enum error_type type) { std::string type_str; int code = 500; switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; } - return json { + return json{ {"code", code}, {"message", message}, {"type", type_str}, }; } -struct token_probabilities { - float sampled_token_p; - std::vector cur; -}; -static token_probabilities get_token_probabilities(llama_context * ctx, int idx, llama_token sampled_token_id, int n_sorted) { - const auto * logits = llama_get_logits_ith(ctx, idx); +token_probabilities get_token_probabilities(llama_context* ctx, int idx, llama_token sampled_token_id, int n_sorted) { + const auto* logits = llama_get_logits_ith(ctx, idx); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); n_sorted = std::min(n_sorted, n_vocab); std::vector> sorted(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) sorted[token_id] = {logits[token_id], token_id}; + for (llama_token token_id = 0; token_id < n_vocab; token_id++) sorted[token_id] = { logits[token_id], token_id }; - std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater>{}); + std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater>{}); float max_l = sorted.front().first; float cum_sum = 0.0f; @@ -1564,7 +1466,7 @@ static token_probabilities get_token_probabilities(llama_context * ctx, int idx, float p = expf(sorted[i].first - max_l); cum_sum += p; if (i < n_sorted) { - cur[i] = {sorted[i].second, sorted[i].first, p}; + cur[i] = { sorted[i].second, sorted[i].first, p }; } if (!sampled_token_found && sorted[i].second == sampled_token_id) { sampled_token_p = p; @@ -1573,65 +1475,28 @@ static token_probabilities 
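// Usage sketch (illustrative): the switch above maps an error_type to the HTTP
// status code and OpenAI-style error body returned to clients.
// ERROR_TYPE_UNAVAILABLE and ERROR_TYPE_NOT_SUPPORTED are custom (non-OpenAI)
// types mapped to 503 and 501 respectively.
static json example_not_found_error() {
    // yields {"code": 404, "message": "model not found", "type": "not_found_error"}
    return format_error_response("model not found", ERROR_TYPE_NOT_FOUND);
}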
get_token_probabilities(llama_context * ctx, int idx, } for (int i = n_sorted; i < n_vocab; ++i) cum_sum += expf(sorted[i].first - max_l); - float inv_cum_sum = 1/cum_sum; + float inv_cum_sum = 1 / cum_sum; for (int i = 0; i < n_sorted; ++i) cur[i].p *= inv_cum_sum; sampled_token_p *= inv_cum_sum; - return {sampled_token_p, cur}; + return { sampled_token_p, cur }; } /** * server_tokens is a helper to manage the input tokens and image for the server. * it is made this way to simplify the logic of KV cache management. */ -struct server_tokens { - bool has_mtmd = false; -private: // disallow accessing these members directly, risking out-of-sync - - // map a **start** index in tokens to the image chunk - // note: the order need to be in-sync with tokens - std::map map_idx_to_media; - - // list of tokens - // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk - // otherwise, it is a normal text token - // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list - // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos - llama_tokens tokens; - - // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] - // idx 0 1 2 3 4 5 6 7 8 9 10 - // pos 0 1 2 3 4 5 5 5 7 7 7 - // map_idx_to_media will contain: {5, img0}, {8, img1} - -public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - server_tokens(const server_tokens&) = delete; - server_tokens& operator=(const server_tokens&) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens&&) = default; - server_tokens& operator=(server_tokens&&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token& operator[](size_t index) const { return tokens[index]; } - - server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { +server_tokens::server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { for (size_t i = 0; i < mtmd_chunks.size(); ++i) { push_back(mtmd_chunks[i]); } } - server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { +server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { } - llama_pos pos_next() const { + llama_pos server_tokens::pos_next() const { if (!has_mtmd) { return tokens.size(); } @@ -1647,7 +1512,7 @@ public: } // for debugging - std::string str() const { + std::string server_tokens::str() const { std::ostringstream oss; oss << "tokens: "; for (size_t idx = 0; idx < tokens.size(); ++idx) { @@ -1668,7 +1533,7 @@ public: return oss.str(); } - const mtmd::input_chunk_ptr& find_chunk(size_t idx) const { + const mtmd::input_chunk_ptr& server_tokens::find_chunk(size_t idx) const { auto it = map_idx_to_media.find(idx); if (it != map_idx_to_media.end()) { return it->second; @@ -1676,7 +1541,7 @@ public: throw std::runtime_error("Chunk not found"); } - void push_back(llama_token tok) { + void server_tokens::push_back(llama_token tok) { if (tok == LLAMA_TOKEN_NULL) { throw std::runtime_error("Invalid token"); } @@ -1684,7 +1549,7 @@ public: } // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk* chunk) { + void server_tokens::push_back(const 
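// Minimal text-only usage sketch for server_tokens (no mtmd), based on the
// members documented above; "ctx" is assumed to be a valid llama_context and
// "prompt" an already tokenized prompt.
static void example_text_only_tokens(llama_context * ctx, const llama_tokens & prompt) {
    server_tokens toks(prompt, /*has_mtmd=*/false);
    toks.push_back(llama_token(0));              // append one more text token
    toks.keep_first(toks.size() - 1);            // drop it again
    // with no media chunks, pos_next() is simply the token count
    GGML_ASSERT(toks.pos_next() == (llama_pos) toks.size());
    std::string text = toks.detokenize(ctx, /*special=*/true);
    (void) text;
}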
mtmd_input_chunk* chunk) { auto type = mtmd_input_chunk_get_type(chunk); if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { GGML_ASSERT(has_mtmd); @@ -1709,7 +1574,7 @@ public: } // appends server tokens, updates the media map. copies media chunks. - void push_back(server_tokens& tokens) { + void server_tokens::push_back(server_tokens& tokens) { size_t start_idx = size(); for (size_t i = 0; i < tokens.size(); i++) { push_back(tokens[i]); @@ -1727,66 +1592,66 @@ public: } // for compatibility with context shift and prompt truncation - void insert(const std::vector& inp_tokens) { + void server_tokens::insert(const std::vector& inp_tokens) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); } // for compatibility with context shift and prompt truncation - void resize(size_t size) { + void server_tokens::resize(size_t size) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens.resize(size); } - llama_token * data() { + llama_token* server_tokens::data() { return tokens.data(); } - llama_tokens::iterator begin() { + llama_tokens::iterator server_tokens::begin() { return tokens.begin(); } - llama_tokens::iterator end() { - return tokens.end(); + llama_tokens::iterator server_tokens::end() { + return tokens.end(); } - llama_tokens::const_iterator cbegin() { + llama_tokens::const_iterator server_tokens::cbegin() { return tokens.cbegin(); } - llama_tokens::const_iterator cend() { + llama_tokens::const_iterator server_tokens::cend() { return tokens.cend(); } - llama_tokens tokens_data() { + llama_tokens server_tokens::tokens_data() { return tokens; } // for compatibility with speculative decoding, ctx shift, slot save/load - const std::vector& get_text_tokens() const { + const std::vector& server_tokens::get_text_tokens() const { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled return tokens; } // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id) { + void server_tokens::set_token(llama_pos pos, llama_token id) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens[pos] = id; } - size_t size() const { + size_t server_tokens::size() const { return tokens.size(); } - bool empty() const { + bool server_tokens::empty() const { return tokens.empty(); } - void clear() { + void server_tokens::clear() { tokens.clear(); } - void keep_first(size_t n) { + void server_tokens::keep_first(size_t n) { GGML_ASSERT(n <= tokens.size()); if (has_mtmd) { if (n == tokens.size()) { @@ -1819,18 +1684,18 @@ public: tokens.resize(n); } - std::string detokenize(const llama_context* ctx, bool special) const { + std::string server_tokens::detokenize(const llama_context* ctx, bool special) const { llama_tokens text_tokens; text_tokens.reserve(tokens.size()); for (const auto& t : tokens) { if (t != LLAMA_TOKEN_NULL) { text_tokens.push_back(t); } - } + } return llama_detokenize(ctx, text_tokens, special); } - std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const { + std::string server_tokens::detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const { std::string str; if (tokens.size() <= start || length == 0) { return str; @@ -1840,7 +1705,7 @@ public: size_t i = 0; size_t count = 0; for (const auto& t : tokens) { - if (t != LLAMA_TOKEN_NULL && i>=start) { + if (t != LLAMA_TOKEN_NULL && i >= start) { text_tokens.push_back(t); ++count; if (count >= length) { @@ 
-1852,7 +1717,7 @@ public: return llama_detokenize(ctx, text_tokens, special); } - size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special, + size_t server_tokens::find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special, size_t start, const size_t length) { std::string str = detokenize(ctx, special, start, length); std::vector tmp; @@ -1860,7 +1725,7 @@ public: return n; } - size_t get_common_prefix_exact(const server_tokens& b) const { + size_t server_tokens::get_common_prefix_exact(const server_tokens& b) const { const size_t max_idx = std::min(tokens.size(), b.tokens.size()); if (!has_mtmd) { @@ -1909,7 +1774,7 @@ public: } - common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const { + common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const { common_prefix token_prefix; size_t n = get_common_prefix_exact(b); // strict token match as a starting point @@ -1962,12 +1827,13 @@ public: prefix.first += n_tok_a; prefix.second += n_tok_a; token_prefix = common_prefix_add(prefix, token_prefix); - } else { + } + else { // do no include image token prefix // only return text token prefix token_prefix = common_prefix_add(prefix, token_prefix); return token_prefix; - } + } } else { // text not match @@ -1985,20 +1851,20 @@ public: // take first n tokens of tokens list a // find the common prefix between a and b - common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const { + common_prefix server_tokens::get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact) const { // not work for mtmd GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled auto tokens = get_text_tokens(); if (n > tokens.size()) { n = tokens.size(); } - llama_tokens copy(tokens.begin(), tokens.begin()+n); + llama_tokens copy(tokens.begin(), tokens.begin() + n); server_tokens a = server_tokens(copy, false); return a.get_common_prefix(ctx, b, exact); } // make sure all text tokens are within the vocab range - bool validate(const struct llama_context* ctx) const { + bool server_tokens::validate(const struct llama_context* ctx) const { const llama_model* model = llama_get_model(ctx); const llama_vocab* vocab = llama_model_get_vocab(model); const int32_t n_vocab = llama_vocab_n_tokens(vocab); @@ -2023,7 +1889,7 @@ public: } // encode and decode the image chunk - int32_t process_chunk( + int32_t server_tokens::process_chunk( llama_context* ctx, mtmd_context* mctx, size_t idx, @@ -2055,7 +1921,7 @@ public: } // Keep the first n_keep and remove n_discard tokens from tokens - void discard_n_tokens(int32_t n_keep, int32_t n_discard) { + void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) { if (n_discard <= 0 || n_keep + n_discard >= size()) { return; } @@ -2064,7 +1930,7 @@ public: for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } - int32_t token_size = (int32_t) size(); + int32_t token_size = (int32_t)size(); new_tokens.resize(token_size - n_discard); clear(); insert(new_tokens); @@ -2072,11 +1938,11 @@ public: } // Similarity between prompt and cached - float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const { + float server_tokens::get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int 
n_keep, int n_discard) const { GGML_ASSERT(n_keep >= 0 && n_discard >= 0); float sim_cur = 0; if (n_keep == 0 && n_discard == 0) { - auto lcp_len= get_common_prefix(ctx, tokens); + auto lcp_len = get_common_prefix(ctx, tokens); sim_cur = get_slot_similarity(lcp_len.second, tokens.size(), size()); } else { @@ -2090,26 +1956,26 @@ public: } // Similarity between common part and cache - float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const { + float server_tokens::get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const { GGML_ASSERT(n_keep >= 0 && n_discard >= 0); float sim_cur = 0; if (n_keep == 0 && n_discard == 0) { auto lcp_len = get_common_prefix(ctx, tokens); - sim_cur = (float) lcp_len.first/size(); + sim_cur = (float)lcp_len.first / size(); } else { // remove tokens due to context shift and compare auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); - sim_cur = (float) lcp_len.first / size(); + sim_cur = (float)lcp_len.first / size(); } return sim_cur; } -}; + // Computes FNV-1a hash of the data -static std::string fnv_hash(const uint8_t * data, size_t len) { +std::string fnv_hash(const uint8_t* data, size_t len) { const uint64_t fnv_prime = 0x100000001b3ULL; uint64_t hash = 0xcbf29ce484222325ULL; @@ -2120,7 +1986,7 @@ static std::string fnv_hash(const uint8_t * data, size_t len) { return std::to_string(hash); } -static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { +server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector files) { mtmd::bitmaps bitmaps; for (auto& file : files) { mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); @@ -2163,7 +2029,7 @@ static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt * - "prompt": [12, 34, "string", 56, 78] * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } */ -static server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) { +server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) { constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; const bool has_mtmd = mctx != nullptr; @@ -2214,7 +2080,7 @@ static server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_con * - "prompt": [[12, 34, 56], [78, 90, 12]] * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] */ -static std::vector tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) { +std::vector tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) { std::vector result; if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { result.reserve(json_prompt.size()); @@ -2231,7 +2097,7 @@ static std::vector tokenize_input_prompts(const llama_vocab* voca return result; } // Assuming raw_buffer has .data() and .size() members -inline void print_files_info(const 
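// Worked example for the cache-similarity path above (n_keep == 0, n_discard == 0):
// if the cached sequence holds 800 tokens and shares a 600-token common prefix
// with the incoming prompt, get_cached_tokens_similarity() returns
//   600 / 800 = 0.75
// i.e. the fraction of the *cache* that is reusable. get_tokens_similarity()
// instead feeds the prefix into get_slot_similarity() (defined elsewhere), which
// also takes the prompt length into account.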
std::vector& files) { +void print_files_info(const std::vector& files) { for (size_t i = 0; i < files.size(); ++i) { const auto& file = files[i]; std::cout << "File " << i << ": Size = " << file.size() << " bytes\n"; @@ -2246,8 +2112,8 @@ inline void print_files_info(const std::vector& files) { } } -inline bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens, - const server_tokens& prompt_tokens, size_t start, const common_prefix & prefix ) { +bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens, + const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix) { std::string common_cache = cache_tokens.detokenize(ctx, true, start, prefix.first); std::string common_prompt = prompt_tokens.detokenize(ctx, true, start, prefix.second); bool equal = common_cache == common_prompt; diff --git a/examples/server/server-common.h b/examples/server/server-common.h new file mode 100644 index 00000000..2289a3c7 --- /dev/null +++ b/examples/server/server-common.h @@ -0,0 +1,455 @@ +#pragma once + +#include "common.h" +#include "log.h" +#include "llama.h" +#include +#include "chat.h" +#include "mtmd.h" +#include "mtmd-helper.h" +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include + + + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "base64.hpp" + + +#include +#include +#include +#include +#include + +// increase max payload length to allow use of larger context size +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +// increase backlog size to avoid connection resets for >> 1 slots +#define CPPHTTPLIB_LISTEN_BACKLOG 512 +// increase max URI length to handle longer prompts in query string +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768 +// disable Nagle's algorithm +#define CPPHTTPLIB_TCP_NODELAY true +#include "httplib.h" + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" + +using json = nlohmann::ordered_json; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) 
LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +extern bool server_verbose; +extern bool server_log_json; + +#ifndef SERVER_VERBOSE +#define SERVER_VERBOSE 1 +#endif + +#if SERVER_VERBOSE != 1 +#define LOG_VERBOSE(MSG, ...) +#else +#define LOG_VERBOSE(MSG, ...) \ + do \ + { \ + if (server_verbose) \ + { \ + server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \ + } \ + } while (0) +#endif + +#define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) +#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) + +using raw_buffer = std::vector; + +void server_log(const char* level, const char* function, int line, const char* message, const json& extra); + +template +static T json_value(const json& body, const std::string& key, const T& default_value) { + // Fallback null to default value + if (body.contains(key) && !body.at(key).is_null()) { + try { + return body.at(key); + } + catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const& err) { + std::stringstream ss; + ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value: " << err.what(); + LOG_WARNING(ss.str().c_str(), body); + return default_value; + } + } + else { + return default_value; + } +} + +// thin wrapper around common_grammar_trigger with (de)serialization functions +struct server_grammar_trigger { + common_grammar_trigger value; + + server_grammar_trigger() = default; + server_grammar_trigger(const common_grammar_trigger& value) : value(value) {} + server_grammar_trigger(const json& in); + + json to_json() const; +}; + + +// +// chat template utils +// + +// +// base64 utils (TODO: move to common in the future) +// + +static const std::string base64_chars = +"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +"abcdefghijklmnopqrstuvwxyz" +"0123456789+/"; + +bool is_base64(uint8_t c); + +std::vector base64_decode(const std::string& encoded_string); + +// +// random string / id +// + +std::string random_string(); + +std::string gen_chatcmplid(); + +std::string gen_tool_call_id(); + +// +// other common utils +// +float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length); + +size_t common_part(const std::vector& a, const std::vector& b); + +size_t common_part(const std::string& a, const std::string& b); + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +size_t validate_utf8(const std::string& text); + +// TODO: reuse llama_detokenize + +std::string tokens_to_str(llama_context* ctx, const llama_tokens& tokens); + +// format incomplete utf-8 multibyte character for output +std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token); + +struct common_prefix { + size_t first = 0; + size_t second = 0; +}; + +common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b); + +common_prefix find_common_string_prefix(const std::string& a_str, const 
std::string& b_str, const std::set& ignore_set); + +size_t find_n_tokens_from_string(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, + std::vector& map); + +std::string remove_with_set(std::string str, const std::set& chars_to_remove); + +common_prefix find_largest_common_number(const std::vector& a_list, const std::vector& b_list); + +size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set& ignore_set, + std::vector& map); + +common_prefix find_common_text_token_prefix(const llama_context* ctx, const llama_tokens& a, const llama_tokens& b, + size_t start, bool exact); + +struct completion_token_output { + llama_token tok; + std::string text_to_send; + float prob; + + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const; + + static float logarithm(float x); + + static std::vector str_to_bytes(const std::string& str); + + static json probs_vector_to_json(const std::vector& probs, bool post_sampling_probs); +}; + +// convert a vector of completion_token_output to json +json probs_vector_to_json(const llama_context* ctx, const std::vector& probs); + +bool server_sent_event(httplib::DataSink& sink, const json& data); + +bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data); + +// +// OAI utils +// +// used by /completions endpoint +json oaicompat_chat_params_parse(const json& body); + +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + std::map chat_template_kwargs; + common_chat_templates* tmpls; + bool allow_image; + bool allow_audio; + bool enable_thinking = true; +}; + +// used by /chat/completions endpoint +json oaicompat_chat_params_parse( + const struct llama_model* model, + json& body, /* openai api json semantics */ + const oaicompat_parser_options& opt, + std::vector& out_files); + +json anthropic_params_from_json( + const struct llama_model* model, + const json& body_in, /* anthropic messages api json semantics */ + const oaicompat_parser_options& opt, + std::vector& out_files); + + +// +// tokenizer and input processing utils +// + +bool json_is_array_of_numbers(const json& data); + +// is array having BOTH numbers & strings? +bool json_is_array_of_mixed_numbers_strings(const json& data); + +// does array have any individual integers/tokens? +bool json_is_array_and_contains_numbers(const json& data); + +// get value by path(key1 / key2) +json json_get_nested_values(const std::vector& paths, const json& js); + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +std::vector tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special); + +json format_tokenizer_response(const std::vector& tokens); + +json format_detokenized_response(const std::string& content); + +json format_error_response(const std::string& message, const enum error_type type); + +struct token_probabilities { + float sampled_token_p; + std::vector cur; +}; + +token_probabilities get_token_probabilities(llama_context* ctx, int idx, llama_token sampled_token_id, int n_sorted); + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. 
+ */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** index in tokens to the image chunk + // note: the order need to be in-sync with tokens + std::map map_idx_to_media; + + // list of tokens + // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk + // otherwise, it is a normal text token + // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list + // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] + // idx 0 1 2 3 4 5 6 7 8 9 10 + // pos 0 1 2 3 4 5 5 5 7 7 7 + // map_idx_to_media will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd); + + server_tokens(const llama_tokens& tokens, bool has_mtmd); + + llama_pos pos_next() const; + + // for debugging + std::string str() const; + + const mtmd::input_chunk_ptr& find_chunk(size_t idx) const; + + void push_back(llama_token tok); + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk* chunk); + + // appends server tokens, updates the media map. copies media chunks. 
+ void push_back(server_tokens& tokens); + + // for compatibility with context shift and prompt truncation + void insert(const std::vector& inp_tokens); + + // for compatibility with context shift and prompt truncation + void resize(size_t size); + + llama_token* data(); + + llama_tokens::iterator begin(); + + llama_tokens::iterator end(); + + llama_tokens::const_iterator cbegin(); + + llama_tokens::const_iterator cend(); + + llama_tokens tokens_data(); + + // for compatibility with speculative decoding, ctx shift, slot save/load + const std::vector& get_text_tokens() const; + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id); + + size_t size() const; + + bool empty() const; + + void clear(); + + void keep_first(size_t n); + + std::string detokenize(const llama_context* ctx, bool special) const; + + std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const; + + size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special, + size_t start, const size_t length); + + size_t get_common_prefix_exact(const server_tokens& b) const; + + + common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const; + // take first n tokens of tokens list a + // find the common prefix between a and b + common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const; + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context* ctx) const; + + // encode and decode the image chunk + int32_t process_chunk( + llama_context* ctx, + mtmd_context* mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t& n_tokens_out) const; + + // Keep the first n_keep and remove n_discard tokens from tokens + void discard_n_tokens(int32_t n_keep, int32_t n_discard); + + // Similarity between prompt and cached + float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const; + + // Similarity between common part and cache + float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const; +}; + +// Computes FNV-1a hash of the data +std::string fnv_hash(const uint8_t* data, size_t len); + +server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector files); + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * use tokenize_input_prompts() if the input could be an array. 
+ * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + */ +server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special); + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] + */ +std::vector tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special); + +// Assuming raw_buffer has .data() and .size() members +void print_files_info(const std::vector& files); + +bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens, + const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix); diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp new file mode 100644 index 00000000..45ef8988 --- /dev/null +++ b/examples/server/server-context.cpp @@ -0,0 +1,2763 @@ +#include "server-context.h" +#include "server-common.h" +#include "server-task.h" +#include "server-queue.h" + +#include "common.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + + +server_context::~server_context() { + if (ctx) { + llama_free(ctx); + ctx = nullptr; + } + + if (model) { + llama_free_model(model); + model = nullptr; + } + // Free multimodal + mtmd_free(mctx); + // Free draft model and context if they exist + if (ctx_draft) { + llama_free(ctx_draft); + ctx_draft = nullptr; + } + if (model_draft) { + llama_free_model(model_draft); + model_draft = nullptr; + } + + // Clear any sampling context + for (server_slot& slot : slots) { + if (slot.ctx_sampling != nullptr) { + llama_sampling_free(slot.ctx_sampling); + } + if (slot.ctx_dft) { + llama_free(slot.ctx_dft); + } + if (slot.spec) { + llama_speculative_free(slot.spec); + } + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); +} + +bool server_context::load_model(const gpt_params& params_) { + params = params_; + + llama_init_result llama_init = llama_init_from_gpt_params(params); + + model = llama_init.model; + ctx = llama_init.context; + lora_adapters = llama_init.lora_adapters; + + if (model == nullptr) { + LOG_ERROR("unable to load model", { {"model", params.model} }); + return false; + } + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_should_add_bos_token(model); + has_eos_token = llama_add_eos_token(model) != 1; + + chat_templates = common_chat_templates_init(model, params.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja, {}); + } + catch (const std::exception& e) { + LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty(); + std::string& mmproj_path = params.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params.n_threads; + mparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED; + mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mparams.image_min_tokens = params.image_min_tokens; + mparams.image_max_tokens = params.image_max_tokens; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params.ctx_shift) { + params.ctx_shift = false; + LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + //if (params.n_cache_reuse) { + // params_base.n_cache_reuse = 0; + // SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + //} + + if (has_draft_model) { + LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + // Load draft model for speculative decoding if specified + if (has_draft_model) { + LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); + + gpt_params params_dft; + params_dft.devices = params.devices_draft; + params_dft.model = params.model_draft; + params_dft.n_gpu_layers = params.n_gpu_layers_draft; + params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft; + params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft; + params_dft.flash_attn = params.flash_attn; + if (!params.draft_params.empty()) { + auto [argc, argv] = parse_command_line("llama-server " + params.draft_params); + if (!gpt_params_parse(argc, argv, params_dft)) { + gpt_params_print_usage(argc, argv, params_dft); + free_command_line(argc, argv); + return false; + }; + free_command_line(argc, argv); + } + LOG_INFO("", { {"model", params_dft.model} }); + if (params_dft.n_ctx == 0) { + params_dft.n_ctx = params.n_ctx_draft; + } + params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx; + params_dft.n_parallel = 1; + + llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); + + llama_model* model_dft = llama_init_dft.model; + if (model_dft == nullptr) { + LOG_ERROR("failed to load draft model", { {"model", params.model_draft} }); + return false; + } + + if (!llama_speculative_are_compatible(ctx, llama_init_dft.context)) { + LOG_INFO("the draft model is not compatible with the target model. 
tokens will be translated between the draft and target models.", { {} }); + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context); + + cparams_dft = llama_context_params_from_gpt_params(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + model_draft = llama_init_dft.model; + ctx_draft = llama_init_dft.context; + } + return true; +} + +void server_context::init() { + const int32_t n_ctx_slot = n_ctx / params.n_parallel; + + LOG_INFO("initializing slots", { {"n_slots", params.n_parallel} }); + + for (int i = 0; i < params.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.n_predict = params.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; + + LOG_INFO("new slot", { + {"id_slot", slot.id}, + {"n_ctx_slot", slot.n_ctx} + }); + + const int ga_n = params.grp_attn_n; + const int ga_w = params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT + GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT + //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT + //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT + + LOG_INFO("slot self-extend", { + {"id_slot", slot.id}, + {"ga_n", ga_n}, + {"ga_w", ga_w} + }); + } + + slot.ga_i = 0; + slot.ga_n = ga_n; + slot.ga_w = ga_w; + + slot.sparams = params.sparams; + + // Initialize speculative decoding if a draft model is loaded + if (ctx_draft) { + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + // slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); // initialized twice + slot.ctx_dft = ctx_draft; + if (slot.ctx_dft == nullptr) { + LOG_ERROR("failed to create draft context", {}); + return; + } + + slot.spec = llama_speculative_init(ctx, slot.ctx_dft); + if (slot.spec == nullptr) { + LOG_ERROR("failed to create speculator", {}); + return; + } + for (auto& pair : params.replacements_draft) { + llama_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); + } + + } + + slot.reset(); + + slots.push_back(std::move(slot)); + } + + default_generation_settings_for_props = get_formated_generation(slots.front()); + default_generation_settings_for_props["seed"] = -1; + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + + // only a single seq_id per token is needed + batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1); + } + + metrics.init(); + + if (params.cache_ram_mib != 0) { + if (params.cache_ram_mib < 0) { + LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit"); + } + else { + LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib); + } + LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + // only apply ram size limit. No token limit for now. + prompt_cache = std::make_unique(ctx, params.cache_ram_mib, 0); + } + else { + LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + + // thinking is enabled if: + // 1. It's not explicitly disabled (reasoning_budget == 0) + // 2. 
The chat template supports it + const bool enable_thinking = params.use_jinja && params.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + //LLAMA_LOG_INFO("Enable thinking? %d\n", enable_thinking); + + oai_parser_opt = { + /* use_jinja */ params.use_jinja, + /* prefill_assistant */ params.prefill_assistant, + /* reasoning_format */ params.reasoning_format, + /* chat_template_kwargs */ params.default_template_kwargs, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? mtmd_support_audio(mctx) : false, + /* enable_thinking */ enable_thinking, + }; +} + + +void server_slot::prompt_save(server_prompt_cache& prompt_cache) const { + assert(server_cached_prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size(ctx, id); + + LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto* cur = prompt_cache.alloc(server_cached_prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id); +} + +void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) { + bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id); + if (!res) { + LLAMA_LOG_INFO("failed to load prompt from cache\n"); + } +} + +void server_slot::reset() { + n_prompt_tokens = 0; + generated_text = ""; + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + n_sent_token_probs = 0; + infill = false; + ga_i = 0; + n_past_se = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + generated_token_probs.clear(); + + + // Reset speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + task.reset(); +} + +bool server_slot::has_budget(gpt_params& global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (params.n_predict != -1) { + n_remaining = params.n_predict - n_decoded; + } + else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget +} + +bool server_slot::available() const { + return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE; +} + +bool server_slot::is_processing() const { + return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING; +} + +void server_slot::add_token_string(const completion_token_output& token) { + if (command == SLOT_COMMAND_RELEASE) { + return; + } + generated_token_probs.push_back(token); +} + +void server_slot::release() { + if (state == SLOT_STATE_PROCESSING) { + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + command = SLOT_COMMAND_RELEASE; + task.reset(); + } +} + + +json server_slot::get_formated_timings() const { + return json{ + {"prompt_n", n_prompt_tokens_processed}, + {"prompt_ms", t_prompt_processing}, + {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, + {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed}, + + {"predicted_n", n_decoded}, + {"predicted_ms", t_token_generation}, + {"predicted_per_token_ms", t_token_generation / n_decoded}, + {"predicted_per_second", 1e3 / 
t_token_generation * n_decoded}, + + {"n_ctx", n_ctx}, + {"n_past", n_past}, + }; +} + +result_timings server_slot::get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + timings.n_ctx = n_ctx; + timings.n_past = n_past; + + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; +} + +const common_chat_msg& server_slot::update_chat_msg(std::vector& diffs) { + auto previous_msg = chat_msg; + auto new_msg = common_chat_parse( + generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, + params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); + } + //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", generated_text.c_str()); + //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.reasoning_content.c_str()); + //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.content.c_str()); + return chat_msg; +} + + +size_t server_slot::find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop) { + size_t stop_pos = std::string::npos; + + for (const std::string& word : params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } + else { + pos = string_find_partial_stop(text, word); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stopped_word = true; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; +} + +void server_slot::print_timings() const { + char buffer[512]; + + double t_token = t_prompt_processing / n_prompt_tokens_processed; + double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", + t_prompt_processing, n_prompt_tokens_processed, + t_token, n_tokens_second); + + LOG_INFO(buffer, { + {"id_slot", id}, + {"id_task", id_task}, + {"t_prompt_processing", t_prompt_processing}, + {"n_prompt_tokens_processed", n_prompt_tokens_processed}, + {"t_token", t_token}, + {"n_tokens_second", n_tokens_second}, + }); + + t_token = t_token_generation / n_decoded; + n_tokens_second = 1e3 / t_token_generation * n_decoded; + + snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", + t_token_generation, n_decoded, + t_token, n_tokens_second); + + LOG_INFO(buffer, { + {"id_slot", id}, + {"id_task", id_task}, + {"t_token_generation", t_token_generation}, + {"n_decoded", n_decoded}, + {"t_token", t_token}, + {"n_tokens_second", n_tokens_second}, + }); + + snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation); + + LOG_INFO(buffer, { + {"id_slot", id}, + {"id_task", id_task}, + {"t_prompt_processing", t_prompt_processing}, + {"t_token_generation", t_token_generation}, + {"t_total", t_prompt_processing + t_token_generation}, + }); +} + +void server_metrics::init() { + t_start = ggml_time_us(); +} + +void server_metrics::on_prompt_eval(const server_slot& slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; +} + +void server_metrics::on_prediction(const server_slot& slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; +} + +void server_metrics::reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; +} + +std::vector server_context::tokenize(const json& json_prompt, bool add_special) const { + // TODO: currently, we tokenize using special tokens by default + // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) + // but it's better compared to completely ignoring ChatML and other chat templates + const bool TMP_FORCE_SPECIAL = true; + + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. 
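+ // NOTE (illustrative examples only; the strings and token ids below are made up):
+ // "prompt": "Hello world" -> tokenized once; BOS may be added when add_special is true
+ // "prompt": ["Hello", 15043, " world"] -> "Hello" is tokenized with add_special (first
+ // string element), 15043 is appended verbatim as a token id, and " world" is tokenized
+ // without BOS; TMP_FORCE_SPECIAL keeps special tokens parsed in every case.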
+ std::vector prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto& p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + std::vector p; + if (first) { + p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); + first = false; + } + else { + p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } + else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } + else { + auto s = json_prompt.template get(); + prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); + } + + return prompt_tokens; +} + +server_slot* server_context::get_slot_by_id(int id) { + for (server_slot& slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; +} + +server_slot* server_context::get_available_slot(const server_task& task) { + server_slot* ret = nullptr; + bool update_cache = false; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int max_lcp_len = 0; + float sim_best = 0; + + for (server_slot& slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + const auto& cache_tokens = slot.cache_tokens; + // skip the slot if it does not contains prompt + if (cache_tokens.empty()) { + continue; + } + // length of the Longest Common Prefix between the current slot's prompt and the input prompt + auto lcp_len = cache_tokens.get_common_prefix(slot.ctx, task.tokens); + // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length + float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0); + // handle context shift + if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size() >= slot.n_ctx) { + float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt); + if (sim_cur_ctx_shift > sim_cur) { + sim_cur = sim_cur_ctx_shift; + } + } + + // select the current slot if the criteria match + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + max_lcp_len = lcp_len.first; + ret = &slot; + } + } + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lcp similarity", { + {"id_slot", ret->id}, + {"max_lcp_len", max_lcp_len}, + {"similarity", sim_best}, + }); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot& slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lru", { + {"id_slot", ret->id}, + {"t_last", t_last}, + }); + } + } + if (ret) { + const auto& tokens = ret->cache_tokens; + float f_keep = 0.0f; + if (!tokens.empty()) { + if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) { + f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt); + } + else { + f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, 0, 0); + } + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < cache_ram_similarity) { + update_cache = 
true; + } + } + update_cache = update_cache && prompt_cache; + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is above cache_ram_n_min + update_cache = update_cache && tokens.size() >= cache_ram_n_min; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", + (int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity); + if (update_cache) { + const int64_t t_start = ggml_time_us(); + LLAMA_LOG_INFO("updating prompt cache\n"); + ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt; + ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt; + + ret->prompt_save(*prompt_cache); + LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + // has prompts saved earlier to load + if (prompt_cache && !prompt_cache->states.empty()) { + const int64_t t_start = ggml_time_us(); + ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt; + ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt; + + ret->prompt_load(*prompt_cache, task.tokens); + prompt_cache->update(); + + ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens + ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt; + ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt; + + LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + } + return ret; +} + +bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) { + slot_params default_params; + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + llama_sampling_params default_sparams = params.sparams; + auto& data = task.data; + + if (data.count("__oaicompat") != 0) { + slot.oaicompat = true; + slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + } + else { + slot.oaicompat = false; + slot.oaicompat_model = ""; + } + slot.params.timings_per_token = json_value(data, "timings_per_token", false); + slot.params.stream = json_value(data, "stream", false); + auto stream_opt = json_value(data, "stream_options", json::object()); + slot.params.include_usage = json_value(stream_opt, "include_usage", false); + slot.params.cache_prompt = json_value(data, "cache_prompt", true); + slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); + slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot.sparams.dynatemp_range = json_value(data, 
"dynatemp_range", default_sparams.dynatemp_range); + slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); + slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability); + slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold); + slot.sparams.top_n_sigma = json_value(data, "top_n_sigma", default_sparams.top_n_sigma); + slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier); + slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base); + slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length); + slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n); + slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); + slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); + slot.sparams.seed = json_value(data, "seed", default_sparams.seed); + slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + + slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs); + + // speculative decoding parameters + slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft); + slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min); + + // Clamp speculative parameters + slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); + slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); + slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); + + if (slot.sparams.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (slot.sparams.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (slot.sparams.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + slot.sparams.penalty_last_n = llama_n_ctx(ctx); + } + + if (slot.sparams.dry_penalty_last_n == -1) { + slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx); + + } + if (slot.sparams.dry_base < 1.0f) + { + slot.sparams.dry_base = default_sparams.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: 
https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (slot.sparams.dry_sequence_breakers.empty()) { + send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + LLAMA_LOG_DEBUG("JSON schema: %s\n", schema.dump(2).c_str()); + slot.sparams.grammar = json_schema_to_grammar(schema); + LLAMA_LOG_DEBUG("Converted grammar: %s\n", slot.sparams.grammar.c_str()); + } + catch (const std::exception& e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } + else { + slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); + LLAMA_LOG_DEBUG("Grammar: %s\n", slot.sparams.grammar.c_str()); + slot.sparams.grammar_lazy = json_value(data, "grammar_lazy", default_sparams.grammar_lazy); + LLAMA_LOG_DEBUG("Grammar lazy: %s\n", slot.sparams.grammar_lazy ? "true" : "false"); + } + + if (slot.params.cache_prompt && slot.ga_n != 1) { + LOG_WARNING("cache_prompt is not supported with group-attention", {}); + slot.params.cache_prompt = false; + } + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { + // Might be better to reject the request with a 400 ? + LOG_WARNING("Max tokens to predict exceeds server configuration", { + {"params.n_predict", slot.params.n_predict}, + {"slot.n_predict", slot.n_predict}, + }); + slot.params.n_predict = slot.n_predict; + } + + // infill + slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); + slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); + + // get prompt + if (!task.infill) { + // maybe not needed since prompt has been tokenized? 
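+ // NOTE (illustrative examples only, token ids are made up): the validation below accepts
+ // {"prompt": "Once upon a time"} // plain string
+ // {"prompt": ["Once upon a time"]} // array holding a single string
+ // {"prompt": [1, 15043, 3186]} // array of token ids (first element is an integer)
+ // {"prompt": [[1, 15043, 3186]]} // array holding a single token array
+ // and rejects anything else (e.g. an array starting with a string that mixes in numbers)
+ // with ERROR_TYPE_INVALID_REQUEST.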
+ const auto& prompt = data.find("prompt"); + if (!slot.prompt_tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + if (prompt == data.end()) { + send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + if ((prompt->is_string()) || + (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) || + (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) { + slot.prompt = *prompt; + } + else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { + slot.prompt = prompt->at(0); + } + else { + send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); + return false; + } + slot.prompt_tokens = std::move(task.tokens); + } + + // penalize user-provided tokens + { + slot.sparams.penalty_prompt_tokens.clear(); + slot.sparams.use_penalty_prompt_tokens = false; + + const auto& penalty_prompt = data.find("penalty_prompt"); + + if (penalty_prompt != data.end()) { + if (penalty_prompt->is_string()) { + const auto penalty_prompt_string = penalty_prompt->get(); + slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); + + if (slot.params.n_predict > 0) { + slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict); + } + slot.sparams.use_penalty_prompt_tokens = true; + + LOG_VERBOSE("penalty_prompt_tokens", { + {"id_slot", slot.id}, + {"tokens", slot.sparams.penalty_prompt_tokens}, + }); + } + else if (penalty_prompt->is_array()) { + const auto n_tokens = penalty_prompt->size(); + slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); + + const int n_vocab = llama_n_vocab(model); + for (const auto& penalty_token : *penalty_prompt) { + if (penalty_token.is_number_integer()) { + const auto tok = penalty_token.get(); + if (tok >= 0 && tok < n_vocab) { + slot.sparams.penalty_prompt_tokens.push_back(tok); + } + } + } + slot.sparams.use_penalty_prompt_tokens = true; + + LOG_VERBOSE("penalty_prompt_tokens", { + {"id_slot", slot.id}, + {"tokens", slot.sparams.penalty_prompt_tokens}, + }); + } + } + } + { + auto it = data.find("chat_format"); + if (it != data.end()) { + slot.params.oaicompat_chat_syntax.format = static_cast(it->get()); + LLAMA_LOG_DEBUG("Chat format: %s\n", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)); + } + else { + slot.params.oaicompat_chat_syntax.format = default_params.oaicompat_chat_syntax.format; + } + common_reasoning_format reasoning_format = params.reasoning_format; + if (data.contains("reasoning_format")) { + reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); + } + slot.params.oaicompat_chat_syntax.reasoning_format = reasoning_format; + slot.params.oaicompat_chat_syntax.reasoning_in_content = slot.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + slot.params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + + slot.params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + } + { + + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto& t : *preserved_tokens) { + auto ids = llama_tokenize(model, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + LOG("Preserved token: %d\n", 
ids[0]); + slot.sparams.preserved_tokens.insert(ids[0]); + } + else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + LOG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto& t : *grammar_triggers) { + server_grammar_trigger ct(t); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto& word = ct.value.value; + auto ids = llama_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + LOG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + slot.sparams.grammar_triggers.push_back(std::move(trigger)); + } + else { + LOG("Grammar trigger word: `%s`\n", word.c_str()); + slot.sparams.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word }); + } + } + else { + //slot.sparams.grammar_triggers.push_back(ct); + if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { + LLAMA_LOG_DEBUG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); + } + else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { + LLAMA_LOG_DEBUG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); + } + else { + throw std::runtime_error("Unknown grammar trigger type"); + } + slot.sparams.grammar_triggers.emplace_back(std::move(ct.value)); + } + } + } + + if (slot.sparams.grammar_lazy && slot.sparams.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } + } + + { + slot.sparams.logit_bias.clear(); + + if (json_value(data, "ignore_eos", false) && has_eos_token) { + slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } + + const auto& logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_n_vocab(model); + for (const auto& el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } + else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } + else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + slot.sparams.logit_bias[tok] = bias; + } + } + else if (el[0].is_string()) { + auto toks = llama_tokenize(model, el[0].get(), false); + for (auto tok : toks) { + slot.sparams.logit_bias[tok] = bias; + } + } + } + } + } + } + + { + slot.params.antiprompt.clear(); + + const auto& stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto& word : *stop) { + if (!word.empty()) { + slot.params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + slot.sparams.samplers_sequence = llama_sampling_types_from_names(*samplers, false); + } + else if (samplers->is_string()) { + slot.sparams.samplers_sequence 
= llama_sampling_types_from_chars(samplers->get()); + } + else { + slot.sparams.samplers_sequence = default_sparams.samplers_sequence; + } + } + } + + { + if (slot.ctx_sampling != nullptr) { + llama_sampling_free(slot.ctx_sampling); + } + slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), slot.sparams); + if (slot.ctx_sampling == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + + slot.command = SLOT_COMMAND_LOAD_PROMPT; + // slot.prompt_tokens.clear(); + + LOG_INFO("slot is processing task", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + }); + + return true; +} + +void server_context::kv_cache_clear() { + LOG_VERBOSE("clearing KV cache", {}); + + // clear the entire KV cache + llama_kv_cache_clear(ctx); + clean_kv_cache = false; +} + +void server_context::system_prompt_update() { + LOG_VERBOSE("system prompt update", { + {"system_prompt", system_prompt}, + }); + + kv_cache_clear(); + system_tokens.clear(); + + if (!system_prompt.empty()) { + system_tokens = ::llama_tokenize(ctx, system_prompt, true); + + const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_tokens_prompt = system_tokens.size(); + + for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i); + + llama_batch_clear(batch); + + for (int32_t j = 0; j < n_tokens; ++j) { + llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); + } + + if (llama_decode(ctx, batch) != 0) { + LOG_ERROR("llama_decode() failed", {}); + return; + } + } + + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i <= params.n_parallel; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + } + } + + system_need_update = false; +} + +bool server_context::system_prompt_set(const std::string& sys_prompt) { + system_prompt = sys_prompt; + + LOG_VERBOSE("system prompt process", { + {"system_prompt", system_prompt}, + }); + + // release all slots + for (server_slot& slot : slots) { + slot.release(); + } + + system_need_update = true; + return true; +} + +bool server_context::process_token(completion_token_output& result, server_slot& slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + // search stop word and delete it + slot.generated_text += token_str; + slot.has_next_token = true; + + if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { + // we can change penalty_prompt_tokens because it is always created from scratch each request + slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); + } + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } + else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) { + stop_pos = 
slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } + else { + result.text_to_send = ""; + } + + slot.add_token_string(result); + if (slot.params.stream) { + send_partial_response(slot, result); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) { + slot.stopped_limit = true; + slot.has_next_token = false; + + LOG_VERBOSE("stopped by limit", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_decoded", slot.n_decoded}, + {"n_predict", slot.params.n_predict}, + }); + } + + if (llama_token_is_eog(model, result.tok)) { + slot.stopped_eos = true; + slot.has_next_token = false; + + LOG_VERBOSE("eos token found", {}); + } + + auto n_ctx_train = llama_n_ctx_train(model); + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 + && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + LOG_WARNING("n_predict is not set and self-context extend is disabled." + " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", { + { "id_slot", slot.id }, + { "params.n_predict", slot.params.n_predict }, + { "slot.n_prompt_tokens", slot.n_prompt_tokens }, + { "slot.n_decoded", slot.n_decoded }, + { "slot.n_predict", slot.n_predict }, + { "n_slots", params.n_parallel }, + { "slot.n_ctx", slot.n_ctx }, + { "n_ctx", n_ctx }, + { "n_ctx_train", n_ctx_train }, + { "ga_n", slot.ga_n }, + }); + slot.truncated = true; + slot.stopped_limit = true; + slot.has_next_token = false; // stop prediction + } + + LOG_VERBOSE("next token", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"token", result.tok}, + {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"n_decoded", slot.n_decoded}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + }); + + return slot.has_next_token; // continue +} + +void server_context::populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.sparams.n_probs; + size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); + + if (post_sampling) { + const auto* cur_p = llama_sampling_get_candidates(slot.ctx_sampling); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + llama_detokenize(ctx, {cur_p->data[i].id}, special), + cur_p->data[i].p + }); + } + } + else { + auto&& [sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs); + + // set probability for sampled token + result.prob = sampled_token_p; + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) 
{ + result.probs.push_back({ + cur[i].id, + llama_detokenize(ctx, {cur[i].id}, special), + cur[i].p + }); + } + } +} + +json server_context::get_formated_generation(const server_slot& slot) const { + const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); + const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); + + std::vector samplers_sequence; + samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); + for (const auto& sampler_type : slot.sparams.samplers_sequence) { + samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); + } + + auto grammar_triggers = json::array(); + for (const auto& trigger : slot.sparams.grammar_triggers) { + grammar_triggers.push_back(trigger.to_json()); + } + + return json{ + {"n_ctx", slot.n_ctx}, + {"n_predict", slot.n_predict}, // Server configured n_predict + {"model", params.model_alias}, + {"seed", slot.sparams.seed}, + {"temperature", slot.sparams.temp}, + {"dynatemp_range", slot.sparams.dynatemp_range}, + {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, + {"top_k", slot.sparams.top_k}, + {"top_p", slot.sparams.top_p}, + {"min_p", slot.sparams.min_p}, + {"tfs_z", slot.sparams.tfs_z}, + {"typical_p", slot.sparams.typical_p}, + {"repeat_last_n", slot.sparams.penalty_last_n}, + {"repeat_penalty", slot.sparams.penalty_repeat}, + {"presence_penalty", slot.sparams.penalty_present}, + {"frequency_penalty", slot.sparams.penalty_freq}, + {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, + {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, + {"dry_multiplier", slot.sparams.dry_multiplier}, + {"dry_base", slot.sparams.dry_base}, + {"dry_allowed_length", slot.sparams.dry_allowed_length}, + {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n}, + {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers}, + {"mirostat", slot.sparams.mirostat}, + {"mirostat_tau", slot.sparams.mirostat_tau}, + {"mirostat_eta", slot.sparams.mirostat_eta}, + {"penalize_nl", slot.sparams.penalize_nl}, + {"stop", slot.params.antiprompt}, + {"max_tokens", slot.params.n_predict}, // User configured n_predict + {"n_keep", slot.params.n_keep}, + {"n_discard", slot.params.n_discard}, + {"ignore_eos", ignore_eos}, + {"stream", slot.params.stream}, + {"logit_bias", slot.sparams.logit_bias}, + {"n_probs", slot.sparams.n_probs}, + {"min_keep", slot.sparams.min_keep}, + {"grammar", slot.sparams.grammar}, + {"grammar_triggers", grammar_triggers}, + {"preserved_tokens", slot.sparams.preserved_tokens}, + {"chat_format", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)}, + {"reasoning_format", common_reasoning_format_name(slot.params.oaicompat_chat_syntax.reasoning_format)}, + {"reasoning_in_content", slot.params.oaicompat_chat_syntax.reasoning_in_content}, + {"thinking_forced_open", slot.params.oaicompat_chat_syntax.thinking_forced_open}, + {"samplers", samplers_sequence} + }; +} + +void server_context::send_error(const server_task& task, const std::string& error, const enum error_type type) { + send_error(task.id, task.id_multi, error, type); +} + +void server_context::send_error(const server_slot& slot, const std::string& error, const enum error_type type) { + send_error(slot.id_task, slot.id_multi, error, type); +} + +void server_context::send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type ) { + LOG_ERROR("task error", { + {"id_multi", id_multi}, + {"id_task", id_task}, + 
{"error", error}, + }); + + server_task_result res; + res.id = id_task; + res.id_multi = id_multi; + res.stop = false; + res.error = true; + res.data = format_error_response(error, type); + + queue_results.send(res); +} + +// if multimodal is enabled, send an error and return false +bool server_context::ensure_no_mtmd(const int id_task) { + if (mctx) { + int id_multi = 0; + send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; +} + +void server_context::send_partial_response(server_slot& slot, completion_token_output tkn) { + server_task_result res; + res.final_result = false; + res.id = slot.id_task; + res.id_multi = slot.id_multi; + res.error = false; + res.stop = false; + res.stream = slot.params.stream; + res.content = tkn.text_to_send; + res.post_sampling_probs = slot.params.post_sampling_probs; + res.oaicompat = slot.params.oaicompat; + res.oaicompat_model = slot.params.oaicompat_model; + res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res.n_decoded = slot.n_decoded; + res.n_prompt_tokens = slot.n_prompt_tokens; + res.data = json{ + {"content", tkn.text_to_send}, + {"stop", false}, + {"id_slot", slot.id}, + {"multimodal", false} + }; + slot.update_chat_msg(res.oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.sparams.n_probs > 0) { + res.probs_output = { tkn }; // copy the token probs + res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output); + } + + if (slot.oaicompat) { + res.data["oaicompat_token_ctr"] = slot.n_decoded; + res.data["model"] = slot.oaicompat_model; + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.params.timings_per_token) { + res.timings = slot.get_timings(); + } + queue_results.send(std::move(res)); +} + +void server_context::send_final_response(server_slot& slot) { + server_task_result res; + res.final_result = true; + res.id = slot.id_task; + res.id_multi = slot.id_multi; + res.error = false; + res.stop = true; // to do: set value + res.stream = slot.params.stream; + res.include_usage = slot.params.include_usage; + res.content = slot.generated_text; + res.timings = slot.get_timings(); + res.post_sampling_probs = slot.params.post_sampling_probs; + res.oaicompat = slot.params.oaicompat; + res.oaicompat_model = slot.params.oaicompat_model; + res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs); + res.n_decoded = slot.n_decoded; + res.n_prompt_tokens = slot.n_prompt_tokens; + res.oaicompat_model = slot.oaicompat_model; + res.data = json{ + {"content", !slot.params.stream ? 
slot.generated_text : ""}, + {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic + {"id_slot", slot.id}, + {"stop", true}, + {"model", params.model_alias}, + {"tokens_predicted", slot.n_decoded}, + {"tokens_evaluated", slot.n_prompt_tokens}, + {"generation_settings", get_formated_generation(slot)}, + {"prompt", slot.prompt}, + {"truncated", slot.truncated}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + {"tokens_cached", slot.n_past}, + {"timings", slot.get_formated_timings()}, + //{"oaicompat_chat_format", slot.params.oaicompat_chat_format}, + }; + + // populate res.probs_output + if (slot.sparams.n_probs > 0) { + res.probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output); + } + + if (slot.oaicompat) { + res.data["oaicompat_token_ctr"] = slot.n_decoded; + res.data["model"] = slot.oaicompat_model; + } + + queue_results.send(std::move(res)); +} + +void server_context::send_embedding(const server_slot& slot, const llama_batch& batch) { + server_task_result res; + res.id = slot.id_task; + res.id_multi = slot.id_multi; + res.error = false; + res.stop = true; + + const int n_embd = llama_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float* embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + LOG_ERROR("failed to get embeddings", { + {"token", batch.token[i]}, + {"seq_id", batch.seq_id[i][0]} + }); + + res.data = json{ + {"embedding", std::vector(n_embd, 0.0f)}, + {"tokens_evaluated", slot.n_prompt_tokens}, + }; + + continue; + } + + llama_embd_normalize(embd, embd_res.data(), n_embd); + + res.data = json{ + {"embedding", embd_res}, + {"tokens_evaluated", slot.n_prompt_tokens}, + }; + } + + queue_results.send(res); +} + +void server_context::request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs) { + server_task task; + task.id = id_task; + task.id_multi = id_multi; + task.id_target = 0; + task.data = std::move(data); + task.infill = infill; + task.embedding = embedding; + task.type = SERVER_TASK_TYPE_COMPLETION; + task.tokens = std::move(inputs); + // when a completion task's prompt array is not a singleton, we split it into multiple requests + // otherwise, it's a single-prompt task, we actually queue it + // if there's numbers in the prompt array it will be treated as an array of tokens + if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { + bool numbers = false; + for (const auto& e : task.data.at("prompt")) { + if (e.is_number()) { + numbers = true; + break; + } + } + + // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, + // it will completely stall the server. I don't know where the bug for this is. + // + // if there are numbers, it needs to be treated like a single prompt, + // queue_tasks handles a mix of strings and numbers just fine. 
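+ // NOTE: illustrative sketch only (hypothetical helper, compiled out) of the scan below:
+ // a multi-element "prompt" array that contains any number is treated as a single
+ // tokenized prompt, otherwise it is split into one subtask per element.
+ #if 0
+ static bool prompt_contains_numbers(const json & prompt) {
+ for (const auto & e : prompt) {
+ if (e.is_number()) {
+ return true;
+ }
+ }
+ return false;
+ }
+ #endif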
+ if (numbers) { + queue_tasks.post(std::move(task)); + } + else { + split_multiprompt_task(id_task, task); + } + } + else { + queue_tasks.post(std::move(task)); + } +} + +void server_context::request_cancel(int id_task) { + server_task task; + task.type = SERVER_TASK_TYPE_CANCEL; + task.id_target = id_task; + + queue_tasks.post(std::move(task)); +} + +void server_context::split_multiprompt_task(int id_multi, server_task& multiprompt_task) { + const int prompt_count = multiprompt_task.data.at("prompt").size(); + if (prompt_count <= 1) { + send_error(multiprompt_task, "error while handling multiple prompts"); + return; + } + + // generate all the ID for subtask + std::vector subtask_ids(prompt_count); + for (int i = 0; i < prompt_count; i++) { + subtask_ids[i] = queue_tasks.get_new_id(); + } + + // queue up the multitask so we can track its subtask progression + queue_tasks.add_multitask(id_multi, subtask_ids); + + // add subtasks + for (int i = 0; i < prompt_count; i++) { + json subtask_data = multiprompt_task.data; + subtask_data["prompt"] = subtask_data.at("prompt")[i]; + + // subtasks inherit everything else (infill mode, embedding mode, etc.) + request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding, + std::move(multiprompt_task.tokens)); + } +} + +void server_context::process_single_task(server_task&& task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + { + const int id_slot = json_value(task.data, "id_slot", -1); + + server_slot* slot; + + if (id_slot != -1) { + slot = get_slot_by_id(id_slot); + } + else { + slot = get_available_slot(task); + } + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + LOG_VERBOSE("no slot is available", { {"id_task", task.id} }); + queue_tasks.defer(std::move(task)); + break; + } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} }); + queue_tasks.defer(std::move(task)); + break; + } + + if (task.data.contains("system_prompt")) { + std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); + system_prompt_set(sys_prompt); + + for (server_slot& slot : slots) { + slot.n_past = 0; + slot.n_past_se = 0; + } + } + + slot->reset(); + + slot->id_task = task.id; + slot->id_multi = task.id_multi; + slot->infill = task.infill; + slot->embedding = task.embedding; + + if (!launch_slot_with_task(*slot, task)) { + LOG_ERROR("error while launching slot", task.data); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto& slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot& slot : slots) { + json slot_data = get_formated_generation(slot); + slot_data["id"] = slot.id; + slot_data["id_task"] = slot.id_task; + slot_data["state"] = slot.state; + slot_data["prompt"] = slot.prompt; + slot_data["next_token"] = { + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"n_decoded", slot.n_decoded}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + }; + + if 
(slot_data["state"] == SLOT_STATE_IDLE) { + n_idle_slots++; + } + else { + n_processing_slots++; + } + + slots_data.push_back(slot_data); + } + LOG_INFO("slot data", { + {"id_task", task.id}, + {"n_idle_slots", n_idle_slots}, + {"n_processing_slots", n_processing_slots} + }); + + LOG_VERBOSE("slot data", { + {"id_task", task.id}, + {"n_idle_slots", n_idle_slots}, + {"n_processing_slots", n_processing_slots}, + {"slots", slots_data} + }); + + server_task_result res; + res.id = task.id; + res.id_multi = task.id_multi; + res.stop = true; + res.error = false; + res.data = { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", queue_tasks.queue_tasks_deferred.size() }, + { "t_start", metrics.t_start}, + + { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, + { "t_tokens_generation_total", metrics.t_tokens_generation_total}, + { "n_tokens_predicted_total", metrics.n_tokens_predicted_total}, + { "t_prompt_processing_total", metrics.t_prompt_processing_total}, + + { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, + { "t_prompt_processing", metrics.t_prompt_processing}, + { "n_tokens_predicted", metrics.n_tokens_predicted}, + { "t_tokens_generation", metrics.t_tokens_generation}, + + { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, + { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, + + { "slots", slots_data }, + }; + + if (json_value(task.data, "reset_bucket", false)) { + metrics.reset_bucket(); + } + queue_results.send(res); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!ensure_no_mtmd(task.id)) { + break; + } + int id_slot = task.data.at("id_slot"); + server_slot* slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} }); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data.at("filename"); + std::string filepath = task.data.at("filepath"); + + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json{ + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", token_count }, // tokens saved + { "n_written", nwrite }, // bytes written + { "timings", { + { "save_ms", t_save_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.data.at("id_slot"); + server_slot* slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} }); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data.at("filename"); + std::string filepath = task.data.at("filepath"); + + slot->cache_tokens.resize(slot->n_ctx); 
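+ // note: the token buffer is grown to a full context worth of entries before reading the
+ // saved sequence; llama_state_seq_load_file() reports how many tokens were actually
+ // restored and the buffer is shrunk back to that count right after.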
+ size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json{ + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", token_count }, // tokens restored + { "n_read", nread }, // bytes read + { "timings", { + { "restore_ms", t_restore_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!ensure_no_mtmd(task.id)) break; + int id_slot = task.data.at("id_slot"); + server_slot* slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", { {"id_task", task.id} }); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + slot->cache_tokens.clear(); + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json{ + { "id_slot", id_slot }, + { "n_erased", n_erased } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + llama_lora_adapters_apply(ctx, lora_adapters); + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json{ { "success", true } }; + queue_results.send(result); + } break; + } +} + +void server_context::on_finish_multitask(const server_task_multi& multitask) { + // all subtasks done == multitask is done + server_task_result result; + result.id = multitask.id; + result.stop = true; + result.error = false; + + // collect json results into one json result + std::vector result_jsons; + for (const auto& subres : multitask.results) { + result_jsons.push_back(subres.data); + result.error = result.error && subres.error; + } + result.data = json{ + { "results", result_jsons } + }; + + queue_results.send(result); +} + +void server_context::print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1, size_t start2, size_t length) { + if (cache.size() > start2) { + LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str()); + } + if (prompt.size() > start1) { + LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str()); + } + +} + +void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) { + llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + if (slot.params.cache_prompt) { + slot.cache_tokens.discard_n_tokens(n_keep, n_discard); + } +} + +// convert keep first few and discard next tokens in a to b +void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, 
int32_t n_keep, + int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) { + + common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact); + common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact); + // only if there is enough common token + int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep); + int32_t keep_offset = ctx_keep_prefix.first - n_keep; + n_kept = ctx_keep_prefix.second - keep_offset; + n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset; + if (n_kept < 0) { + n_kept = n_keep; + } + if (n_discarded < 0) { + n_discarded = n_discard; + } +} + +void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact) { + //server_tokens prompt_tokens = std::move(slot.prompt_tokens); + int n_keep = std::max(0, slot.params.n_keep + add_bos_token); + const int n_left = slot.n_ctx - n_keep; + int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + + int n_discard_prompt = 0; + // we still need to truncate input since we have not discarded enough tokens + while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) { + slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard; + n_discard_prompt = n_discard_prompt + n_discard; + } + + // Handle mistokenization between prompt and cache during context shift + // + int32_t n_discard_cache = n_discard_prompt; + int32_t n_kept = n_keep; + slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt); + if (n_discard_prompt > 0) { + context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep, + n_discard, n_kept, n_discard_cache, exact); + } + + int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0); + n_discard_cache = std::min(n_discard_cache, n_discard_cache_max); + // discard matching tokens from cache and kv cache to avoid reprocessing the prompt + if (n_discard_cache > 0) { + discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache); + } + // discard extra tokens from prompts + slot.n_kept_prompt = n_keep; + slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt); + slot.n_prompt_tokens = slot.prompt_tokens.size(); +} + +void server_context::update_slots() { + if (system_need_update) { + system_prompt_update(); + } + + // release slots + for (auto& slot : slots) { + if (slot.command == SLOT_COMMAND_RELEASE) { + slot.state = SLOT_STATE_IDLE; + slot.command = SLOT_COMMAND_NONE; + slot.t_last_used = ggml_time_us(); + + LOG_INFO("slot released", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", n_ctx}, + {"n_past", slot.n_past}, + {"n_system_tokens", system_tokens.size()}, + {"n_cache_tokens", slot.cache_tokens.size()}, + {"truncated", slot.truncated} + }); + + queue_tasks.notify_slot_changed(); + } + } + + // check if all slots are idle + { + bool all_idle = true; + + for (auto& slot : slots) { + if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) { + all_idle = false; + break; + } + } + + if (all_idle) { + LOG_INFO("all slots are idle", {}); + if (system_prompt.empty() && clean_kv_cache) { + kv_cache_clear(); + } + + return; + } + } + + { + LOG_VERBOSE("posting NEXT_RESPONSE", {}); + + server_task task; + task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; + task.id_target = -1; + + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot& slot : slots) { 
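+ // note: for slots without group attention (ga_n == 1) the shift below keeps the first
+ // n_keep tokens and discards n_discard of the remainder (half of n_left unless n_discard
+ // is set explicitly), possibly adjusted for prompt/cache mistokenization by
+ // context_shift_find_n_tokens(). For example, with no system tokens, n_ctx = 4096,
+ // n_keep = 256 and n_past = 4095: n_left = 3839 and n_discard = 1919, so 1919 tokens
+ // after the kept prefix are dropped from the KV cache and the rest are shifted down to
+ // free space for further generation.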
+ if (slot.ga_n == 1) {
+ if (slot.is_processing() && (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+ if (!params.ctx_shift) {
+ // this check is redundant (kept here for good measure)
+ // we should never get here, because generation should have already been stopped in process_token()
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+ slot.release();
+ continue;
+ }
+ if (mctx) {
+ // we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded
+ // we don't support ctx_shift because an image chunk may contain multiple tokens
+ GGML_ABORT("not supported by multimodal");
+ }
+ // Shift context
+ int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep;
+ if (add_bos_token) {
+ n_keep += 1;
+ }
+ n_keep = std::min(slot.n_ctx - 4, n_keep);
+
+ const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+ int32_t n_kept;
+ int32_t n_discard_cache;
+ if (n_discard > 0) {
+ context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep,
+ n_discard, n_kept, n_discard_cache);
+ LOG_INFO("slot context shift", {
+ {"id_slot", slot.id},
+ {"id_task", slot.id_task},
+ {"n_keep", n_keep},
+ {"n_left", n_left},
+ {"n_discard", n_discard},
+ {"n_ctx", n_ctx},
+ {"n_past", slot.n_past},
+ {"n_system_tokens", system_tokens.size()},
+ {"n_cache_tokens", slot.cache_tokens.size()}
+ });
+ slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard;
+ slot.n_kept_prompt = n_keep;
+ discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache);
+ slot.n_past -= n_discard_cache;
+ slot.truncated = true;
+ }
+
+ }
+ }
+ }
+
+ // start populating the batch for this iteration
+ llama_batch_clear(batch);
+
+ auto accept_special_token = [&](server_slot& slot, llama_token token) {
+ return params.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
+ };
+
+ // first, add sampled tokens from any ongoing sequences
+ for (auto& slot : slots) {
+ if (slot.state == SLOT_STATE_IDLE) {
+ continue;
+ }
+
+ slot.i_batch = batch.n_tokens;
+
+ const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+
+ // TODO: we always have to take into account the "system_tokens"
+ // this is not great and needs to be improved somehow
+ llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.cache_tokens.pos_next(), { slot.id }, true);
+
+ slot.n_past += 1;
+
+ if (slot.params.cache_prompt) {
+ slot.cache_tokens.push_back(slot.sampled);
+ }
+
+ LOG_VERBOSE("slot decode token", {
+ {"id_slot", slot.id},
+ {"id_task", slot.id_task},
+ {"n_ctx", n_ctx},
+ {"n_past", slot.n_past},
+ {"n_system_tokens", system_tokens.size()},
+ {"n_cache_tokens", slot.cache_tokens.size()},
+ {"truncated", slot.truncated}
+ });
+ }
+
+ // process in chunks of params.n_batch
+ int32_t n_batch = llama_n_batch(ctx);
+ int32_t n_ubatch = llama_n_ubatch(ctx);
+
+ // track if this is an embedding or non-embedding batch
+ // if we've added sampled tokens above, we are in non-embedding mode
+ // -1: none, 0: non-embedding, 1: embedding
+ int32_t batch_type = batch.n_tokens > 0 ? 
0 : -1; + + // next, batch any pending prompts without exceeding n_batch + if (params.cont_batching || batch.n_tokens == 0) { + for (auto& slot : slots) { + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) { + auto& prompt_tokens = slot.prompt_tokens; + + // we haven't tokenized the prompt yet - do it now: + if (prompt_tokens.empty() || slot.n_prompt_tokens == 0) { + LOG_VERBOSE("tokenizing prompt", { + {"id_slot", slot.id}, + {"id_task", slot.id_task} + }); + + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; + + if (slot.infill) { + const bool add_bos = llama_should_add_bos_token(model); + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + + auto prefix_tokens = tokenize(slot.params.input_prefix, false); + auto suffix_tokens = tokenize(slot.params.input_suffix, false); + + const int space_token = 29871; // TODO: this should not be hardcoded + if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } + + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); + suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); + + auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; + auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens; + if (add_bos) { + embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + } + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + + const llama_token middle_token = llama_token_middle(model); + if (middle_token >= 0) { + embd_inp.push_back(middle_token); + } + + prompt_tokens = server_tokens(embd_inp, false); + } + else { + // prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt + } + + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + + LOG_VERBOSE("prompt tokenized", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_prompt_tokens", slot.n_prompt_tokens}, + {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, + }); + + // empty prompt passed -> release the slot and send empty response + if (prompt_tokens.empty()) { + LOG_INFO("empty prompt - releasing slot", { + {"id_slot", slot.id}, + {"id_task", slot.id_task} + }); + + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; + slot.release(); + slot.print_timings(); + send_final_response(slot); + continue; + } + + if (slot.embedding) { + // this prompt is too large to process - discard it + if (slot.n_prompt_tokens > n_ubatch) { + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; + slot.release(); + send_error(slot, "input is too large to process. 
increase the physical batch size", ERROR_TYPE_SERVER); + continue; + } + } + else { + // if input prompt is too big, truncate it (if group attention self-extend is disabled) + // context shift for prompt processing + if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { + if (!params.ctx_shift) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + if (mctx) { + // we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + context_shift_prompt(ctx, slot); + slot.truncated = true; + LOG_VERBOSE("input truncated", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", slot.n_ctx - slot.params.n_keep}, + {"n_prompt_tokens", slot.n_prompt_tokens}, + {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, + }); + + GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + +#ifndef NDEBUG + // debug + common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); + int32_t back = 1; + if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first + 20 + && prefix.second >= back && prefix.first >= back) { + LLAMA_LOG_INFO("After context shift :\n"); + print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50); + } +#endif + } + else { + slot.n_discarded_prompt = 0; + } + llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); + + if (!slot.params.cache_prompt) { + slot.n_past_se = 0; + slot.ga_i = 0; + } + else { + GGML_ASSERT(slot.ga_n == 1); + + // reuse any previously computed tokens that are common with the new prompt + common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match + common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); + auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match + LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t)slot.cache_tokens.size(), (int32_t)n_past0, (int32_t)prefix.first, (int32_t)prefix.second, (int32_t)prefix_nonexact.first, (int32_t)prefix_nonexact.second); + int32_t size_threshold = 20; + if (prefix.first + size_threshold < prefix_nonexact.first) { + LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n"); + prefix = prefix_nonexact; + } + slot.n_past = prefix.first; + slot.n_past_prompt = prefix.second; + if (slot.n_past != slot.n_past_prompt) { + LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n"); + } + if ((slot.n_past + size_threshold < slot.cache_tokens.size())) + { + LLAMA_LOG_WARN("Common part does not match fully\n"); + int32_t back = 4; + if (prefix.second >= back && prefix.first >= back) { + print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30); + } + } + + // push the prompt into the sampling context (do not apply grammar) + for (int i = 0; i < slot.n_past; ++i) { + llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + } + } + } + + if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) { + // we have to evaluate at least 1 token to generate logits. 
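+ // NOTE (editorial, not part of the original patch): llama_decode only produces logits for tokens that are actually evaluated in the batch, so if the + // entire prompt is already covered by the cached prefix there would be nothing to decode and nothing to sample from. + // Rolling n_past_prompt / n_past back by one forces the last prompt token to be re-evaluated, which yields the logits needed for the first sampled token.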
+ LOG_INFO("we have to evaluate at least 1 token to generate logits", { + { "id_slot", slot.id }, + { "id_task", slot.id_task } + }); + + slot.n_past_prompt--; + slot.n_past--; + if (slot.ga_i > 0) { + slot.n_past_se--; + } + } + + slot.n_prompt_tokens_processed = 0; + } + + if (slot.embedding) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + continue; + } + } + + // check that we are in the right batch_type, if not defer the slot + bool slot_type = slot.embedding ? 1 : 0; + if (batch_type == -1) { + batch_type = slot_type; + } + else if (batch_type != slot_type) { + continue; + } + + // keep only the common part + // remove the non-common part from the cache + slot.cache_tokens.keep_first(slot.n_past); + int p0 = (int)system_tokens.size() + slot.n_past; + p0 = system_tokens.size() + slot.cache_tokens.pos_next(); + if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); + + p0 = (int)system_tokens.size(); + if (p0 != 0) { + // copy over the system prompt when there is one + llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1); + } + + // there is no common part left (except for the system prompt) + slot.n_past = 0; + slot.n_past_se = 0; + slot.ga_i = 0; + // TODO: is the system prompt ever in the sampling context? + llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); + } + + LOG_INFO("kv cache rm [p0, end)", { + { "id_slot", slot.id }, + { "id_task", slot.id_task }, + { "p0", p0 } + }); + + // check if we should process the image + if (slot.n_past_prompt < slot.n_prompt_tokens + && slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) { + // process the image + size_t n_tokens_out = 0; + llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out); + if (res != 0) { + LLAMA_LOG_ERROR("failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + // add the image chunk to cache + { + const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt); + slot.cache_tokens.push_back(chunk.get()); // copy + } + + slot.n_past += n_tokens_out; + slot.n_past_prompt += n_tokens_out; + slot.n_prompt_tokens_processed += n_tokens_out; + + } + + + + int32_t slot_npast = slot.n_past_se > 0 ? 
slot.n_past_se : slot.n_past; + + int32_t ga_i = slot.ga_i; + int32_t ga_n = slot.ga_n; + int32_t ga_w = slot.ga_w; + + // add prompt tokens for processing in the current batch + // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow + while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + if (slot.ga_n != 1) { + while (slot_npast >= ga_i + ga_w) { + const int bd = (ga_w / ga_n) * (ga_n - 1); + slot_npast -= bd; + ga_i += ga_w / ga_n; + } + } + + int p0 = system_tokens.size() + slot.cache_tokens.pos_next(); + llama_batch_add(batch, cur_tok, p0, { slot.id }, false); + + slot.cache_tokens.push_back(cur_tok); + + + slot.n_prompt_tokens_processed++; + slot_npast++; + slot.n_past_prompt++; + slot.n_past++; + } + LOG_VERBOSE("prompt processing progress", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + {"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens}, + }); + + // entire prompt has been processed - start decoding new tokens + if (slot.n_past_prompt == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; + + GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size()); + llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + llama_sampling_accept(slot.ctx_sampling, ctx, id, false); + } + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + LOG_VERBOSE("prompt done", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + }); + } + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + LOG_VERBOSE("no tokens to decode", {}); + return; + } + + LOG_VERBOSE("decoding batch", { + {"n_tokens", batch.n_tokens}, + }); + + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, batch_type == 1); + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + + for (auto& slot : slots) { + if (slot.ga_n != 1) { + // context extension via Self-Extend + // TODO: simplify and/or abstract this + while (slot.n_past_se >= slot.ga_i + slot.ga_w) { + const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; + const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); + const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; + + LOG_TEE("\n"); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); + LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); + + llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); + llama_kv_cache_seq_div(ctx, slot.id, 
slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); + + slot.n_past_se -= bd; + + slot.ga_i += slot.ga_w / slot.ga_n; + + LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); + } + + slot.n_past_se += n_tokens; + } + } + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", { + {"i", i}, + {"n_batch", ret}, + {"ret", ret}, + }); + for (auto& slot : slots) { + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; + slot.release(); + LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size()); + send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); + } + break; // break loop of n_batch + } + + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + i -= n_batch; + + LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", { + {"i", i}, + {"n_batch", n_batch}, + {"ret", ret}, + }); + + continue; // continue loop of n_batch + } + + for (auto& slot : slots) { + if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { + continue; // continue loop of slots + } + + // prompt evaluated for embedding + if (slot.embedding) { + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + completion_token_output result; + const int tok_idx = slot.i_batch - i; + const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx); + + llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = ggml_time_us(); + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; + + result.tok = id; + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + + if (slot.sparams.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx); + } + + if (!process_token(result, slot)) { + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + } + + slot.i_batch = -1; + } + + // Do speculative decoding + for (auto& slot : slots) { + if (!slot.is_processing() || !slot.spec) { + continue; + } + + if (slot.state != SLOT_STATE_PROCESSING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // 
note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_predict > 0) { + n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1); + } + + LOG_VERBOSE("max possible draft", { + {"id_slot", slot.id}, + {"n_draft_max", n_draft_max} + }); + + if (n_draft_max < slot.params.speculative.n_min) { + LOG_VERBOSE("the max possible draft is too small", { + {"id_slot", slot.id}, + {"n_draft_max", n_draft_max}, + {"n_min", slot.params.speculative.n_min} + }); + continue; + } + + llama_token id = slot.sampled; + + struct llama_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const std::vector& cached_text_tokens = slot.cache_tokens.tokens_data(); + std::vector draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int)draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + llama_batch_clear(slot.batch_spec); + llama_batch_add(slot.batch_spec, id, slot.cache_tokens.pos_next(), { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + llama_batch_add(slot.batch_spec, draft[i], slot.cache_tokens.pos_next() + 1 + i, { slot.id }, true); + } + + LOG_VERBOSE("decoding speculative batch", { + {"id_slot", slot.id}, + {"size", slot.batch_spec.n_tokens} + }); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + std::vector ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 }); + + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + if (slot.sparams.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + LOG_VERBOSE("speculative decoding result", { + {"id_slot", slot.id}, + {"accepted", (int)ids.size() - 1}, + {"total", (int)draft.size()}, + {"new_n_past", slot.n_past} + }); + } + } + + LOG_VERBOSE("run slots completed", {}); +} + +json server_context::model_meta() const { + return json{ + {"vocab_type", llama_vocab_type(model)}, + {"n_vocab", llama_n_vocab(model)}, + {"n_ctx_train", llama_n_ctx_train(model)}, + {"n_embd", llama_n_embd(model)}, + {"n_params", llama_model_n_params(model)}, + {"size", llama_model_size(model)}, + }; +} diff --git a/examples/server/server-context.h b/examples/server/server-context.h new file mode 100644 index 
00000000..568db95a --- /dev/null +++ b/examples/server/server-context.h @@ -0,0 +1,316 @@ +#include "server-task.h" +#include "server-queue.h" +#include "speculative.h" +#include "json-schema-to-grammar.h" +#include + +#include +#include + + + +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_PROCESSING, +}; + + + +enum slot_command { + SLOT_COMMAND_NONE, + SLOT_COMMAND_LOAD_PROMPT, + SLOT_COMMAND_RELEASE, +}; + + + + +struct slot_params { + bool stream = true; + bool include_usage = false; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + + std::vector antiprompt; + + bool timings_per_token = false; + bool post_sampling_probs = false; + json input_prefix; + json input_suffix; + + // speculative decoding parameters + struct { + int n_max = 16; // max drafted tokens + int n_min = 0; // min drafted tokens to accept + float p_min = 0.75f; // min probability required to accept a token in the draft + } speculative; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + +}; + + +struct server_slot { + int id; + int id_task = -1; + int id_multi = -1; + + struct slot_params params; + + slot_state state = SLOT_STATE_IDLE; + slot_command command = SLOT_COMMAND_NONE; + + llama_context* ctx = nullptr; + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + std::unique_ptr task; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_past_prompt = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t n_discarded_prompt = 0; + int32_t n_kept_prompt = 0; + + int32_t i_batch = -1; + int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + + int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_processed = 0; + + json prompt; // can be either a string, array of strings or array of token ids + + // when a task is submitted, we first tokenize the prompt and store it here + server_tokens prompt_tokens; + server_tokens cache_tokens; + + std::string generated_text; + + std::vector generated_token_probs; + common_chat_msg chat_msg; + + bool infill = false; + bool embedding = false; + bool has_next_token = true; + bool truncated = false; + bool stopped_eos = false; + bool stopped_word = false; + bool stopped_limit = false; + + bool oaicompat = false; + + std::string oaicompat_model; + std::string stopping_word; + stop_type stop; + + server_prompt server_cached_prompt; + + void prompt_save(server_prompt_cache& prompt_cache) const; + + void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens); + + // sampling + llama_token sampled; + struct llama_sampling_params sparams; + llama_sampling_context* ctx_sampling = nullptr; + json json_schema; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; + + int32_t ga_i = 0; // group-attention state + int32_t ga_n = 1; // group-attention factor + int32_t ga_w = 512; // group-attention width + + // multimodal + mtmd_context* mctx = nullptr; + + // speculative decoding + struct llama_speculative* spec = nullptr; + llama_context* ctx_dft = nullptr; + llama_batch batch_spec = {}; + + // 
speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + int32_t n_past_se = 0; // self-extend + + // stats + size_t n_sent_text = 0; // number of sent text character + size_t n_sent_token_probs = 0; + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + void reset(); + + bool has_budget(gpt_params& global_params); + + bool available() const; + + bool is_processing() const; + + void add_token_string(const completion_token_output& token); + + void release(); + + json get_formated_timings() const; + + result_timings get_timings() const; + + const common_chat_msg& update_chat_msg(std::vector& diffs); + + size_t find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop); + + void print_timings() const; +}; + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + void init(); + + void on_prompt_eval(const server_slot& slot); + + void on_prediction(const server_slot& slot); + + void reset_bucket(); +}; + +struct server_context { + llama_model* model = nullptr; + llama_context* ctx = nullptr; + std::vector lora_adapters; + + gpt_params params; + + llama_batch batch; + + bool clean_kv_cache = true; + bool add_bos_token = true; + bool has_eos_token = false; + + // multimodal + mtmd_context* mctx = nullptr; + + // For speculative decoding + llama_model* model_draft = nullptr; + llama_context* ctx_draft = nullptr; + llama_context_params cparams_dft; + + int32_t n_ctx; // total context for all clients / slots + + // system prompt + bool system_need_update = false; + + std::string system_prompt; + std::vector system_tokens; + + // slots / clients + std::vector slots; + json default_generation_settings_for_props; + + server_queue queue_tasks; + server_response queue_results; + + std::unique_ptr prompt_cache; + + server_metrics metrics; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + int32_t cache_ram_n_min = 0; + float cache_ram_similarity = 0.5f; + + ~server_context(); + + bool load_model(const gpt_params& params_); + + void init(); + + std::vector tokenize(const json& json_prompt, bool add_special) const; + + server_slot* get_slot_by_id(int id); + + server_slot* get_available_slot(const server_task& task); + + bool launch_slot_with_task(server_slot& slot, server_task& task); + + void kv_cache_clear(); + + void system_prompt_update(); + + bool system_prompt_set(const std::string& sys_prompt); + + bool process_token(completion_token_output& result, server_slot& slot); + + void populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx); + + json get_formated_generation(const server_slot& slot) const; + + void send_error(const server_task& task, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER); + + void send_error(const server_slot& slot, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER); + + void send_error(const int 
id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER); + + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task); + + void send_partial_response(server_slot& slot, completion_token_output tkn); + + void send_final_response(server_slot& slot); + + void send_embedding(const server_slot& slot, const llama_batch& batch); + + void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs); + + void request_cancel(int id_task); + + void split_multiprompt_task(int id_multi, server_task& multiprompt_task); + + void process_single_task(server_task&& task); + + void on_finish_multitask(const server_task_multi& multitask); + + void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10); + + void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard); + + // convert keep first few and discard next tokens in a to b + void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep, + int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false); + + void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false); + + void update_slots(); + + json model_meta() const; +}; diff --git a/examples/server/server-queue.cpp b/examples/server/server-queue.cpp new file mode 100644 index 00000000..c9a75223 --- /dev/null +++ b/examples/server/server-queue.cpp @@ -0,0 +1,194 @@ +#include "server-task.h" +#include "server-queue.h" +#include "server-common.h" + +#include "log.h" +#include + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define RES_DBG(fmt, ...) 
LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) + + +int server_queue::post(server_task task) { + std::unique_lock lock(mutex_tasks); + if (task.id == -1) { + task.id = id++; + //LOG_VERBOSE("new task id", { {"new_id", task.id} }); + QUE_DBG("new task, id = %d\n", task.id); + } + queue_tasks.push_back(std::move(task)); + condition_tasks.notify_one(); + return task.id; +} + +void server_queue::defer(server_task&& task) { + std::unique_lock lock(mutex_tasks); + queue_tasks_deferred.push_back(std::move(task)); +} + +int server_queue::get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + //LOG_VERBOSE("new task id", { {"new_id", new_id} }); + QUE_DBG("new task, id = %d\n", id); + return new_id; +} + +void server_queue::notify_slot_changed() { + // move deferred tasks back to main loop + std::unique_lock lock(mutex_tasks); + for (auto& task : queue_tasks_deferred) { + queue_tasks.push_back(std::move(task)); + } + queue_tasks_deferred.clear(); +} + +void server_queue::on_new_task(std::function callback) { + callback_new_task = std::move(callback); +} + + +void server_queue::start_loop() { + running = true; + + while (true) { + LOG_VERBOSE("new task may arrive", {}); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = std::move(queue_tasks.front()); + queue_tasks.erase(queue_tasks.begin()); + lock.unlock(); + //LOG_VERBOSE("callback_new_task", { {"id_task", task.id} }); + callback_new_task(std::move(task)); + } + + LOG_VERBOSE("update_multitasks", {}); + + // check if we have any finished multitasks + auto queue_iterator = queue_multitasks.begin(); + while (queue_iterator != queue_multitasks.end()) { + if (queue_iterator->subtasks_remaining.empty()) { + // all subtasks done == multitask is done + server_task_multi current_multitask = *queue_iterator; + callback_finish_multitask(current_multitask); + // remove this multitask + queue_iterator = queue_multitasks.erase(queue_iterator); + } + else { + ++queue_iterator; + } + } + + // all tasks in the current loop is processed, slots data is now ready + LOG_VERBOSE("callback_update_slots", {}); + + callback_update_slots(); + + LOG_VERBOSE("wait for new task", {}); + { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + if (!running) { + LOG_VERBOSE("ending start_loop", {}); + return; + } + condition_tasks.wait(lock, [&] { + return (!queue_tasks.empty() || !running); + }); + } + } + } +} + + +void server_queue::add_multitask(int id_multi, std::vector& sub_ids) { + std::lock_guard lock(mutex_tasks); + server_task_multi multi; + multi.id = id_multi; + std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); + queue_multitasks.push_back(multi); +} + + +void server_queue::update_multitask(int id_multi, int id_sub, server_task_result& result) { + std::lock_guard lock(mutex_tasks); + for (auto& multitask : queue_multitasks) { + if (multitask.id == id_multi) { + multitask.subtasks_remaining.erase(id_sub); + multitask.results.push_back(result); + } + } +} + + +void server_response::add_waiting_task_id(int id_task) { + //LOG_VERBOSE("waiting for task id", { {"id_task", id_task} }); + QUE_DBG("waiting for task id, id = %d\n", id_task); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); +} + +void server_response::remove_waiting_task_id(int id_task) { + //LOG_VERBOSE("remove waiting for task id", { {"id_task", id_task} }); + QUE_DBG("remove waiting 
for task id, id = %d\n", id_task); + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); +} + + +server_task_result server_response::recv(int id_task) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&] { + return !queue_results.empty(); + }); + + for (int i = 0; i < (int)queue_results.size(); i++) { + if (queue_results[i].id == id_task) { + assert(queue_results[i].id_multi == -1); + server_task_result res = queue_results[i]; + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here +} + +void server_response::send(server_task_result result) { + //LOG_VERBOSE("send new result", { {"id_task", result.id} }); + QUE_DBG("send new result, id = %d\n", result.id); + std::unique_lock lock(mutex_results); + for (const auto& id_task : waiting_task_ids) { + // LOG_TEE("waiting task id %i \n", id_task); + // for now, tasks that have associated parent multitasks just get erased once the multitask picks up the result + if (result.id_multi == id_task) { + //LOG_VERBOSE("callback_update_multitask", { {"id_task", id_task} }); + QUE_DBG("callback_update_multitask, id = %d\n", id_task); + callback_update_multitask(id_task, result.id, result); + continue; + } + + if (result.id == id_task) { + //LOG_VERBOSE("queue_results.push_back", { {"id_task", id_task} }); + QUE_DBG("queue_results.push_back, id = %d\n", id_task); + queue_results.push_back(result); + condition_results.notify_all(); + return; + } + } +} diff --git a/examples/server/server-queue.h b/examples/server/server-queue.h new file mode 100644 index 00000000..cadff28a --- /dev/null +++ b/examples/server/server-queue.h @@ -0,0 +1,117 @@ +#pragma once +#include "server-task.h" + +#include +#include +#include +#include + +struct server_task_multi { + int id = -1; + + std::set subtasks_remaining; + std::vector results; +}; + + +struct server_queue { + int id = 0; + bool running; + + // queues + std::vector queue_tasks; + std::vector queue_tasks_deferred; + + std::vector queue_multitasks; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_finish_multitask; + std::function callback_update_slots; + + + // Add a new task to the end of the queue + int post(server_task task); + + // Add a new task, but defer until one slot is available + void defer(server_task&& task); + + // Get the next id for creating a new task + int get_new_id(); + + // Register function to process a new task + void on_new_task(std::function callback); + + // Register function to process a multitask when it is finished + void on_finish_multitask(std::function callback) { + callback_finish_multitask = std::move(callback); + } + + // Register the function to be called when all slot data is ready to be processed + void on_update_slots(std::function callback) { + callback_update_slots = std::move(callback); + } + + // Call when the state of one slot is changed + void notify_slot_changed(); + + // end the start_loop routine + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); + } + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. 
maybe copy data into slot) + * - Check if multitask is finished + * - Update all slots + */ + void start_loop(); + + // + // functions to manage multitasks + // + + // add a multitask by specifying the ids of all its subtasks (a subtask is a server_task) + void add_multitask(int id_multi, std::vector& sub_ids); + + // update the remaining subtasks, while appending results to the multitask + void update_multitask(int id_multi, int id_sub, server_task_result& result); +}; + +struct server_response { + typedef std::function callback_multitask_t; + callback_multitask_t callback_update_multitask; + + // for keeping track of all tasks waiting for the result + std::set waiting_task_ids; + + // the main result queue + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + + // add the id_task to the list of tasks waiting for a response + void add_waiting_task_id(int id_task); + + // when the request is finished, we can remove the task associated with it + void remove_waiting_task_id(int id_task); + + // This function blocks the thread until there is a response for this id_task + server_task_result recv(int id_task); + + // Register the function to update a multitask + void on_multitask_update(callback_multitask_t callback) { + callback_update_multitask = std::move(callback); + } + + // Send a new result to a waiting id_task + void send(server_task_result result); +}; diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp new file mode 100644 index 00000000..335dc85c --- /dev/null +++ b/examples/server/server-task.cpp @@ -0,0 +1,816 @@ +#include "server-task.h" + + +json result_timings::to_json() const { + json base = { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + + {"n_ctx", n_ctx}, + {"n_past", n_past}, + }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; +} + + +json server_task_result::to_json_final() { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat_final(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat_final(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final(); + case OAICOMPAT_TYPE_ANTHROPIC: + return stream ? 
to_json_anthropic_stream() : to_json_anthropic_final(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } +} + +json server_task_result::to_json_partial() { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat_partial(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat_partial(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat_partial(); + case OAICOMPAT_TYPE_ANTHROPIC: + return to_json_anthropic_partial(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } +} + +json server_task_result::to_json_non_oaicompat_partial() { + // non-OAI-compat JSON + json res = json{ + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_multi}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({ "timings", timings.to_json() }); + } + if (!probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return res; +} + +json server_task_result::to_json_non_oaicompat_final() { + json res = json{ + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? std::vector {} : tokens}, + {"id_slot", id_multi}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + //{"generation_settings", default_generation_settings_for_props.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + //{"stop_type", stop_type_to_str(STOP_TYPE_EOS)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? 
res : json_get_nested_values(response_fields, res); +} + +json server_task_result::to_json_oaicompat_partial() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json res = json{ + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat_partial(); + } + if (timings.prompt_n >= 0) { + res.push_back({ "timings", timings.to_json() }); + } + + return res; +} + +json server_task_result::to_json_oaicompat_final() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json{ + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat_final(); + } + if (timings.prompt_n >= 0) { + res.push_back({ "timings", timings.to_json() }); + } + + return res; +} + +json server_task_result::to_json_oaicompat_chat_partial() { + bool first = n_decoded == 1; + std::time_t t = std::time(0); + json choices; + + std::vector deltas; + auto add_delta = [&](const json& delta) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + }; + // We have to send an initial update to conform to openai behavior + if (first) { + add_delta({ + {"role", "assistant"}, + {"content", nullptr}, + }); + } + + for (const auto& diff : oaicompat_msg_diffs) { + add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); + } + + if (!deltas.empty()) { + GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1); + + if (probs_output.size() > 0) { + deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + deltas[deltas.size() - 1].push_back({ "timings", timings.to_json() }); + } + } + + return deltas; +} + +json server_task_result::to_json_oaicompat_chat_final() 
{ + std::string finish_reason = "length"; + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } + else { + msg.role = "assistant"; + msg.content = content; + } + if (stop) { + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + + json choice{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", msg.to_json_oaicompat()}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json{ + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat_final(); + } + if (timings.prompt_n >= 0) { + res.push_back({ "timings", timings.to_json() }); + } + + return res; +} + +json server_task_result::to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop) { + //if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"; + } + + json deltas = json::array(); + for (const auto& diff : oaicompat_msg_diffs) { + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}, + }); + } + + deltas.push_back({ + {"choices", json::array({ + json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}, + }, + })}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}, + }); + if (include_usage) { + // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage + // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices + deltas.push_back({ + {"choices", json::array()}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }); + } + if (timings.prompt_n >= 0) { + deltas.back().push_back({ "timings", timings.to_json() }); + } + // extra fields for debugging purposes + if (verbose && !deltas.empty()) { + deltas.front()["__verbose"] = to_json_non_oaicompat_final(); + } + + return deltas; +} + +json server_task_result::to_json_anthropic_final() { + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? 
"end_turn" : "tool_use"; + } + + json content_blocks = json::array(); + + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } + else { + msg.role = "assistant"; + msg.content = content; + } + + + if (!msg.content.empty()) { + content_blocks.push_back({ + {"type", "text"}, + {"text", msg.content} + }); + } + + for (const auto& tool_call : msg.tool_calls) { + json tool_use_block = { + {"type", "tool_use"}, + {"id", tool_call.id}, + {"name", tool_call.name} + }; + + try { + tool_use_block["input"] = json::parse(tool_call.arguments); + } + catch (const std::exception&) { + tool_use_block["input"] = json::object(); + } + + content_blocks.push_back(tool_use_block); + } + + json res = { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", content_blocks}, + {"model", oaicompat_model}, + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded} + }} + }; + + return res; +} + +json server_task_result::to_json_anthropic_stream() { + json events = json::array(); + + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + bool has_text = !oaicompat_msg.content.empty(); + size_t num_tool_calls = oaicompat_msg.tool_calls.size(); + + bool text_block_started = false; + std::set tool_calls_started; + + for (const auto& diff : oaicompat_msg_diffs) { + + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + + if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { + const auto& full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; + + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", full_tool_call.id}, + {"name", full_tool_call.name} + }} + }} + }); + tool_calls_started.insert(diff.tool_call_index); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + if (has_text) { + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", 0} + }} + }); + } + + for (size_t i = 0; i < num_tool_calls; i++) { + size_t content_block_index = (has_text ? 
1 : 0) + i; + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", content_block_index} + }} + }); + } + + events.push_back({ + {"event", "message_delta"}, + {"data", { + {"type", "message_delta"}, + {"delta", { + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)} + }}, + {"usage", { + {"output_tokens", n_decoded} + }} + }} + }); + + events.push_back({ + {"event", "message_stop"}, + {"data", { + {"type", "message_stop"} + }} + }); + + // extra fields for debugging purposes + if (verbose && !events.empty()) { + events.front()["data"]["__verbose"] = to_json_non_oaicompat_final(); + } + // Don't add timings for Anthropic API (breaks spec compliance) + if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && timings.prompt_n >= 0 && !events.empty()) { + events.back()["data"]["timings"] = timings.to_json(); + } + + return events; +} + +json server_task_result::to_json_anthropic_partial() { + json events = json::array(); + bool first = n_decoded == 1; + static bool text_block_started = false; + + if (first) { + text_block_started = false; + + events.push_back({ + {"event", "message_start"}, + {"data", { + {"type", "message_start"}, + {"message", { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", json::array()}, + {"model", oaicompat_model}, + {"stop_reason", nullptr}, + {"stop_sequence", nullptr}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", 0} + }} + }} + }} + }); + } + + for (const auto& diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (text_block_started ? 
1 : 0) + diff.tool_call_index; + + if (!diff.tool_call_delta.name.empty()) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", diff.tool_call_delta.id}, + {"name", diff.tool_call_delta.name} + }} + }} + }); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + if (verbose && !events.empty() && first) { + events.front()["data"]["__verbose"] = to_json_non_oaicompat_partial(); + } + + if (timings.prompt_n >= 0 && !events.empty()) { + events.back()["data"]["timings"] = timings.to_json(); + } + + //if (is_progress && !events.empty()) { + // events.back()["data"]["prompt_progress"] = progress.to_json(); + //} + + return events; +} + + +size_t server_prompt::size() const { + size_t res = data.size(); + + for (const auto& checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; +} + +size_t server_prompt_cache::size() const { + size_t res = 0; + + for (const auto& state : states) { + res += state.size(); + } + + return res; +} + +size_t server_prompt_cache::n_tokens() const { + size_t res = 0; + + for (const auto& state : states) { + res += state.n_tokens(); + } + return res; + +} + +bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) { + const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new); + + float f_keep_best = float(lcp_best.second) / prompt.tokens.size(); + float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt); + LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new); + const float f_keep_cur = float(lcp_cur.first) / it->tokens.size(); + const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt); + if (sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + it_best = it; + } + } + + if (it_best != states.end()) { + LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt); + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot); + if (n != size) { + LLAMA_LOG_INFO("failed to restore state with size %zu\n", size); + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; +} + +server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t state_size) { + for (auto it = states.begin(); it != states.end();) { + auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens + 
tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt); + auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift); + const size_t len = prefix.first; + const size_t len_prompt = prefix.second; + // first check if the current state is contained fully in the cache + if (len_prompt == tokens_ctx_shift.size()) { + LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + // next, remove any cached prompts that are fully contained in the current prompt + else if (len == it->tokens.size()) { + LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", (int)len); + it = states.erase(it); + } + else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } + catch (const std::bad_alloc& e) { + LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4 * size()); + + LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto& cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.n_keep =*/ prompt.n_kept_prompt, + /*.n_discarded_prompt =*/ prompt.n_discarded_prompt, + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; +} + + +void server_prompt_cache::update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + // average size per token + const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); + + // dynamically increase the token limit if it can fit in the memory limit + const size_t limit_tokens_cur = limit_size > 0 ? 
std::max(limit_tokens, limit_size / size_per_token) : limit_tokens; + + LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); + + for (const auto& state : states) { + LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n", + (const void*)&state, state.n_tokens(), state.n_discarded_prompt, state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } +} diff --git a/examples/server/server-task.h b/examples/server/server-task.h new file mode 100644 index 00000000..4be2e001 --- /dev/null +++ b/examples/server/server-task.h @@ -0,0 +1,216 @@ +#pragma once +#include "common.h" +#include "llama.h" + +#include +#include +#include +// TODO: prevent including the whole server-common.h as we only use server_tokens +#include "server-common.h" + +using json = nlohmann::ordered_json; + +enum stop_type { + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, +}; + + + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, +}; + +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, + OAICOMPAT_TYPE_ANTHROPIC, +}; + + +struct server_task { + int id = -1; // to be filled by server_queue + int id_multi = -1; + int id_target = -1; + //int id_slot = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + server_tokens tokens; + + server_task_type type; + json data; + + bool infill = false; + bool embedding = false; + + server_task() = default; + server_task(server_task_type type) : type(type) {} + +}; + +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; + + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; + int32_t n_ctx = 0; + int32_t n_past = 0; + + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + + json to_json() const; +}; + +struct server_task_result { + int id = -1; + int id_multi = -1; + + json data; + + bool stop; + bool error; + bool final_result = false; + result_timings timings; + // OAI-compat fields + //bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; + + int index = 0; + + std::string content; + std::vector tokens; + + bool stream; + bool include_usage; + std::string prompt; + //slot_params generation_params; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + + bool post_sampling_probs = false; + std::vector probs_output; + std::vector response_fields; + + //slot_params generation_params; + + bool verbose = false; + + + int get_index() { + return index; + } + + bool is_stop() { + return true; // in stream mode, final responses are considered stop + } + + json to_json_final(); + + json to_json_partial(); + + json 
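// Sketch of the bookkeeping in server_prompt_cache::update(): estimate the average state size
// per cached token, then, when a byte limit is configured, let the effective token limit grow
// as long as the estimate still fits in the byte budget. All numbers below are illustrative,
// not real measurements.
#include <algorithm>
#include <cstdio>

int main() {
    const size_t cache_bytes  = 512u * 1024 * 1024;  // current cache size in bytes
    const size_t cache_tokens = 65536;               // tokens across all cached prompts
    const size_t limit_size   = 1024u * 1024 * 1024; // byte limit (0 = no limit)
    const size_t limit_tokens = 32768;               // configured token limit (0 = no limit)

    const float size_per_token =
        std::max(1.0f, float(cache_bytes) / float(std::max<size_t>(1, cache_tokens)));

    const size_t limit_tokens_cur = limit_size > 0
        ? std::max(limit_tokens, size_t(limit_size / size_per_token))
        : limit_tokens;

    std::printf("%.1f bytes/token -> effective token limit: %zu (configured: %zu)\n",
                size_per_token, limit_tokens_cur, limit_tokens);
    return 0;
}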
to_json_non_oaicompat_partial(); + + json to_json_non_oaicompat_final(); + + json to_json_oaicompat_partial(); + + json to_json_oaicompat_final(); + + json to_json_oaicompat_chat_partial(); + + json to_json_oaicompat_chat_final(); + + json to_json_oaicompat_chat_stream(); + + json to_json_anthropic_final(); + + json to_json_anthropic_stream(); + + json to_json_anthropic_partial(); +}; + + +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + + +struct server_prompt { + server_tokens tokens; + int n_kept_prompt; + int n_discarded_prompt; + + std::vector data; + + std::list checkpoints; + + size_t size() const; + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(llama_context* ctx, int32_t limit_size_mib, size_t limit_tokens) { + this->ctx = ctx; + this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + llama_context* ctx; + size_t size() const; + + size_t n_tokens() const; + + server_prompt* alloc(const server_prompt& prompt, size_t state_size); + + bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot); + + void update(); +}; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 50371d57..54719061 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,14 +1,13 @@ #pragma warning(disable : 4996) +#include "server-context.h" +#include "server-common.h" #include "chat.h" -#include "utils.hpp" #include "common.h" #include "speculative.h" #include "mtmd.h" #include "sampling.h" -#include "json-schema-to-grammar.h" #include "llama.h" -#include "grammar-parser.h" #include "llama-vocab.h" #ifndef NDEBUG @@ -53,799 +52,12 @@ bool server_verbose = false; bool server_log_json = true; - -enum stop_type { - STOP_TYPE_NONE, - STOP_TYPE_EOS, - STOP_TYPE_WORD, - STOP_TYPE_LIMIT, -}; -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_PROCESSING, -}; - -enum slot_command { - SLOT_COMMAND_NONE, - SLOT_COMMAND_LOAD_PROMPT, - SLOT_COMMAND_RELEASE, -}; - enum server_state { SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet SERVER_STATE_READY, // Server is ready and model is loaded SERVER_STATE_ERROR // An error occurred, load_model failed }; -enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, - SERVER_TASK_TYPE_EMBEDDING, - SERVER_TASK_TYPE_RERANK, - SERVER_TASK_TYPE_INFILL, - SERVER_TASK_TYPE_CANCEL, - SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS, - SERVER_TASK_TYPE_SLOT_SAVE, - SERVER_TASK_TYPE_SLOT_RESTORE, - SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, -}; - -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, - OAICOMPAT_TYPE_ANTHROPIC, -}; - -struct result_timings { - int32_t prompt_n = -1; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n = -1; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; - int32_t n_ctx = 0; - int32_t n_past = 0; - - // Optional speculative metrics - only included when > 0 - int32_t draft_n = 0; - int32_t draft_n_accepted = 0; - - json to_json() const { - json base = { - {"prompt_n", prompt_n}, - {"prompt_ms", 
prompt_ms}, - {"prompt_per_token_ms", prompt_per_token_ms}, - {"prompt_per_second", prompt_per_second}, - - {"predicted_n", predicted_n}, - {"predicted_ms", predicted_ms}, - {"predicted_per_token_ms", predicted_per_token_ms}, - {"predicted_per_second", predicted_per_second}, - - {"n_ctx", n_ctx}, - {"n_past", n_past}, - }; - - if (draft_n > 0) { - base["draft_n"] = draft_n; - base["draft_n_accepted"] = draft_n_accepted; - } - - return base; - } -}; - -struct server_task { - int id = -1; // to be filled by server_queue - int id_multi = -1; - int id_target = -1; - //int id_slot = -1; - - // used by SERVER_TASK_TYPE_INFERENCE - server_tokens tokens; - - server_task_type type; - json data; - - bool infill = false; - bool embedding = false; - - server_task() = default; - server_task(server_task_type type) : type(type) {} - -}; - -struct server_task_result { - int id = -1; - int id_multi = -1; - - json data; - - bool stop; - bool error; - bool final_result = false; - result_timings timings; - // OAI-compat fields - //bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - common_chat_msg oaicompat_msg; - std::vector oaicompat_msg_diffs; - - int index = 0; - - std::string content; - std::vector tokens; - - bool stream; - bool include_usage; - std::string prompt; - //slot_params generation_params; - - bool truncated; - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t n_tokens_cached; - bool has_new_line; - std::string stopping_word; - - bool post_sampling_probs = false; - std::vector probs_output; - std::vector response_fields; - - //slot_params generation_params; - - bool verbose = false; - - - int get_index() { - return index; - } - - bool is_stop() { - return true; // in stream mode, final responses are considered stop - } - - json to_json_final() { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat_final(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat_final(); - case OAICOMPAT_TYPE_CHAT: - return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final(); - case OAICOMPAT_TYPE_ANTHROPIC: - return stream ? to_json_anthropic_stream() : to_json_anthropic_final(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_partial() { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: - return to_json_non_oaicompat_partial(); - case OAICOMPAT_TYPE_COMPLETION: - return to_json_oaicompat_partial(); - case OAICOMPAT_TYPE_CHAT: - return to_json_oaicompat_chat_partial(); - case OAICOMPAT_TYPE_ANTHROPIC: - return to_json_anthropic_partial(); - default: - GGML_ASSERT(false && "Invalid oaicompat_type"); - } - } - - json to_json_non_oaicompat_partial() { - // non-OAI-compat JSON - json res = json{ - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"stop", false}, - {"id_slot", id_multi}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - }; - // populate the timings object when needed (usually for the last response or with timings_per_token enabled) - if (timings.prompt_n > 0) { - res.push_back({ "timings", timings.to_json() }); - } - if (!probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return res; - } - - json to_json_non_oaicompat_final() { - json res = json{ - {"index", index}, - {"content", stream ? 
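// Sketch of the shape produced by result_timings::to_json(): the base timing fields are always
// emitted, while the speculative-decoding counters are appended only when draft tokens were
// actually generated. Uses nlohmann::json; the numbers are placeholders.
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    const int    prompt_n     = 128;
    const double prompt_ms    = 250.0;
    const int    predicted_n  = 64;
    const double predicted_ms = 800.0;
    const int    draft_n      = 0; // no speculative decoding in this run

    json timings = {
        {"prompt_n",               prompt_n},
        {"prompt_ms",              prompt_ms},
        {"prompt_per_token_ms",    prompt_ms / prompt_n},
        {"prompt_per_second",      1e3 / prompt_ms * prompt_n},

        {"predicted_n",            predicted_n},
        {"predicted_ms",           predicted_ms},
        {"predicted_per_token_ms", predicted_ms / predicted_n},
        {"predicted_per_second",   1e3 / predicted_ms * predicted_n},

        {"n_ctx",  4096},
        {"n_past", prompt_n + predicted_n},
    };

    if (draft_n > 0) {
        timings["draft_n"]          = draft_n;
        timings["draft_n_accepted"] = 0;
    }

    std::printf("%s\n", timings.dump(2).c_str());
    return 0;
}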
"" : content}, // in stream mode, content is already in last partial chunk - {"tokens", stream ? std::vector {} : tokens}, - {"id_slot", id_multi}, - {"stop", true}, - {"model", oaicompat_model}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - //{"generation_settings", default_generation_settings_for_props.to_json()}, - {"prompt", prompt}, - {"has_new_line", has_new_line}, - {"truncated", truncated}, - //{"stop_type", stop_type_to_str(STOP_TYPE_EOS)}, - {"stopping_word", stopping_word}, - {"tokens_cached", n_tokens_cached}, - {"timings", timings.to_json()}, -}; - if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return response_fields.empty() ? res : json_get_nested_values(response_fields, res); - } - - json to_json_oaicompat_partial() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json res = json{ - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", nullptr}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat_partial(); - } - if (timings.prompt_n >= 0) { - res.push_back({ "timings", timings.to_json() }); - } - - return res; - } - - json to_json_oaicompat_final() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (!stream && probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - json res = json{ - {"choices", json::array({ - json{ - {"text", stream ? 
"" : content}, // in stream mode, content is already in last partial chunk - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", finish_reason}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat_final(); - } - if (timings.prompt_n >= 0) { - res.push_back({ "timings", timings.to_json() }); - } - - return res; - } - - json to_json_oaicompat_chat_partial() { - bool first = n_decoded == 1; - std::time_t t = std::time(0); - json choices; - - std::vector deltas; - auto add_delta = [&](const json& delta) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - }; - // We have to send an initial update to conform to openai behavior - if (first) { - add_delta({ - {"role", "assistant"}, - {"content", nullptr}, - }); - } - - for (const auto& diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); - } - - if (!deltas.empty()) { - GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1); - - if (probs_output.size() > 0) { - deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - if (timings.prompt_n >= 0) { - deltas[deltas.size() - 1].push_back({ "timings", timings.to_json() }); - } - } - - return deltas; - } - - json to_json_oaicompat_chat_final() { - std::string finish_reason = "length"; - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } - else { - msg.role = "assistant"; - msg.content = content; - } - if (stop) { - finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; - } - - - json choice{ - {"finish_reason", finish_reason}, - {"index", 0}, - {"message", msg.to_json_oaicompat()}, - }; - - if (!stream && probs_output.size() > 0) { - choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - std::time_t t = std::time(0); - - json res = json{ - {"choices", json::array({choice})}, - {"created", t}, - {"model", oaicompat_model}, - {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat_final(); - } - if (timings.prompt_n >= 0) { - res.push_back({ "timings", timings.to_json() }); - } - - return res; - } - - json to_json_oaicompat_chat_stream() { - std::time_t t = std::time(0); - std::string finish_reason = "length"; - if (stop) { - //if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = oaicompat_msg.tool_calls.empty() ? 
"stop" : "tool_calls"; - } - - json deltas = json::array(); - for (const auto& diff : oaicompat_msg_diffs) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}, - }); - } - - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}, - }); - if (include_usage) { - // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage - // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices - deltas.push_back({ - {"choices", json::array()}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - } - if (timings.prompt_n >= 0) { - deltas.back().push_back({ "timings", timings.to_json() }); - } - // extra fields for debugging purposes - if (verbose && !deltas.empty()) { - deltas.front()["__verbose"] = to_json_non_oaicompat_final(); - } - - return deltas; - } - - json to_json_anthropic_final() { - std::string stop_reason = "max_tokens"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; - } - - json content_blocks = json::array(); - - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - - - if (!msg.content.empty()) { - content_blocks.push_back({ - {"type", "text"}, - {"text", msg.content} - }); - } - - for (const auto & tool_call : msg.tool_calls) { - json tool_use_block = { - {"type", "tool_use"}, - {"id", tool_call.id}, - {"name", tool_call.name} - }; - - try { - tool_use_block["input"] = json::parse(tool_call.arguments); - } catch (const std::exception &) { - tool_use_block["input"] = json::object(); - } - - content_blocks.push_back(tool_use_block); - } - - json res = { - {"id", oaicompat_cmpl_id}, - {"type", "message"}, - {"role", "assistant"}, - {"content", content_blocks}, - {"model", oaicompat_model}, - {"stop_reason", stop_reason}, - {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, - {"usage", { - {"input_tokens", n_prompt_tokens}, - {"output_tokens", n_decoded} - }} - }; - - return res; - } - - json to_json_anthropic_stream() { - json events = json::array(); - - std::string stop_reason = "max_tokens"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - stop_reason = oaicompat_msg.tool_calls.empty() ? 
"end_turn" : "tool_use"; - } - - bool has_text = !oaicompat_msg.content.empty(); - size_t num_tool_calls = oaicompat_msg.tool_calls.size(); - - bool text_block_started = false; - std::set tool_calls_started; - - for (const auto & diff : oaicompat_msg_diffs) { - - if (!diff.content_delta.empty()) { - if (!text_block_started) { - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", 0}, - {"content_block", { - {"type", "text"}, - {"text", ""} - }} - }} - }); - text_block_started = true; - } - - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", 0}, - {"delta", { - {"type", "text_delta"}, - {"text", diff.content_delta} - }} - }} - }); - } - - if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; - - if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { - const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; - - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", content_block_index}, - {"content_block", { - {"type", "tool_use"}, - {"id", full_tool_call.id}, - {"name", full_tool_call.name} - }} - }} - }); - tool_calls_started.insert(diff.tool_call_index); - } - - if (!diff.tool_call_delta.arguments.empty()) { - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", content_block_index}, - {"delta", { - {"type", "input_json_delta"}, - {"partial_json", diff.tool_call_delta.arguments} - }} - }} - }); - } - } - } - - if (has_text) { - events.push_back({ - {"event", "content_block_stop"}, - {"data", { - {"type", "content_block_stop"}, - {"index", 0} - }} - }); - } - - for (size_t i = 0; i < num_tool_calls; i++) { - size_t content_block_index = (has_text ? 1 : 0) + i; - events.push_back({ - {"event", "content_block_stop"}, - {"data", { - {"type", "content_block_stop"}, - {"index", content_block_index} - }} - }); - } - - events.push_back({ - {"event", "message_delta"}, - {"data", { - {"type", "message_delta"}, - {"delta", { - {"stop_reason", stop_reason}, - {"stop_sequence", stopping_word.empty() ? 
nullptr : json(stopping_word)} - }}, - {"usage", { - {"output_tokens", n_decoded} - }} - }} - }); - - events.push_back({ - {"event", "message_stop"}, - {"data", { - {"type", "message_stop"} - }} - }); - - // extra fields for debugging purposes - if (verbose && !events.empty()) { - events.front()["data"]["__verbose"] = to_json_non_oaicompat_final(); - } - // Don't add timings for Anthropic API (breaks spec compliance) - if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && timings.prompt_n >= 0 && !events.empty()) { - events.back()["data"]["timings"] = timings.to_json(); - } - - return events; - } - - json to_json_anthropic_partial() { - json events = json::array(); - bool first = n_decoded == 1; - static bool text_block_started = false; - - if (first) { - text_block_started = false; - - events.push_back({ - {"event", "message_start"}, - {"data", { - {"type", "message_start"}, - {"message", { - {"id", oaicompat_cmpl_id}, - {"type", "message"}, - {"role", "assistant"}, - {"content", json::array()}, - {"model", oaicompat_model}, - {"stop_reason", nullptr}, - {"stop_sequence", nullptr}, - {"usage", { - {"input_tokens", n_prompt_tokens}, - {"output_tokens", 0} - }} - }} - }} - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (!text_block_started) { - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", 0}, - {"content_block", { - {"type", "text"}, - {"text", ""} - }} - }} - }); - text_block_started = true; - } - - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", 0}, - {"delta", { - {"type", "text_delta"}, - {"text", diff.content_delta} - }} - }} - }); - } - - if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (text_block_started ? 
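// Sketch of the terminal events appended by to_json_anthropic_stream(): once all content
// blocks are closed, a "message_delta" event reports the stop reason and output token count,
// and a final "message_stop" event ends the stream. Values are placeholders.
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json events = json::array();

    events.push_back({
        {"event", "message_delta"},
        {"data", {
            {"type",  "message_delta"},
            {"delta", { {"stop_reason", "end_turn"}, {"stop_sequence", nullptr} }},
            {"usage", { {"output_tokens", 64} }},
        }},
    });

    events.push_back({
        {"event", "message_stop"},
        {"data",  { {"type", "message_stop"} }},
    });

    std::printf("%s\n", events.dump(2).c_str());
    return 0;
}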
1 : 0) + diff.tool_call_index; - - if (!diff.tool_call_delta.name.empty()) { - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", content_block_index}, - {"content_block", { - {"type", "tool_use"}, - {"id", diff.tool_call_delta.id}, - {"name", diff.tool_call_delta.name} - }} - }} - }); - } - - if (!diff.tool_call_delta.arguments.empty()) { - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", content_block_index}, - {"delta", { - {"type", "input_json_delta"}, - {"partial_json", diff.tool_call_delta.arguments} - }} - }} - }); - } - } - } - - if (verbose && !events.empty() && first) { - events.front()["data"]["__verbose"] = to_json_non_oaicompat_partial(); - } - - if (timings.prompt_n >= 0 && !events.empty()) { - events.back()["data"]["timings"] = timings.to_json(); - } - - //if (is_progress && !events.empty()) { - // events.back()["data"]["prompt_progress"] = progress.to_json(); - //} - - return events; - } -}; static inline std::string stop_type_to_str(stop_type type) { switch (type) { @@ -857,45 +69,6 @@ static inline std::string stop_type_to_str(stop_type type) { } -struct server_task_multi { - int id = -1; - - std::set subtasks_remaining; - std::vector results; -}; - -struct slot_params { - bool stream = true; - bool include_usage = false; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - - std::vector antiprompt; - - bool timings_per_token = false; - bool post_sampling_probs = false; - json input_prefix; - json input_suffix; - - // speculative decoding parameters - struct { - int n_max = 16; // max drafted tokens - int n_min = 0; // min drafted tokens to accept - float p_min = 0.75f; // min probability required to accept a token in the draft - } speculative; - - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; - -}; - - inline std::string get_model_name(std::string path) { std::string filename = path.substr(path.find_last_of("/\\") + 1); @@ -903,3351 +76,6 @@ inline std::string get_model_name(std::string path) }; -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; - - size_t size() const { - return data.size(); - } -}; - - -struct server_prompt { - server_tokens tokens; - int n_kept_prompt; - int n_discarded_prompt; - - std::vector data; - - std::list checkpoints; - - size_t size() const { - size_t res = data.size(); - - for (const auto& checkpoint : checkpoints) { - res += checkpoint.size(); - } - - return res; - } - - int n_tokens() const { - return tokens.size(); - } -}; - -struct server_prompt_cache { - server_prompt_cache(llama_context * ctx,int32_t limit_size_mib, size_t limit_tokens) { - this->ctx = ctx; - this->limit_size = 1024ull * 1024ull * (limit_size_mib < 0 ? 
0 : limit_size_mib); - this->limit_tokens = limit_tokens; - } - - std::list states; - - // in bytes, 0 = no limit - size_t limit_size = 0; - - // in tokens, 0 = no limit - size_t limit_tokens = 0; - llama_context* ctx; - size_t size() const { - size_t res = 0; - - for (const auto& state : states) { - res += state.size(); - } - - return res; - } - - size_t n_tokens() const { - size_t res = 0; - - for (const auto& state : states) { - res += state.n_tokens(); - } - return res; - } - - server_prompt* alloc(const server_prompt& prompt, size_t state_size) { - for (auto it = states.begin(); it != states.end();) { - auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens - tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt); - auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift); - const size_t len = prefix.first; - const size_t len_prompt = prefix.second; - // first check if the current state is contained fully in the cache - if (len_prompt == tokens_ctx_shift.size()) { - LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n"); - return nullptr; - } - // next, remove any cached prompts that are fully contained in the current prompt - else if(len == it->tokens.size()) { - LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", (int)len); - it = states.erase(it); - } - else { - ++it; - } - } - - std::vector state_data; - - // check if we can allocate enough memory for the new state - try { - state_data.resize(state_size); - } - catch (const std::bad_alloc& e) { - LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what()); - - limit_size = std::max(1, 0.4 * size()); - - LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); - - update(); - - return nullptr; - } - - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround - auto& cur = states.emplace_back(); - cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.n_keep =*/ prompt.n_kept_prompt, - /*.n_discarded_prompt =*/ prompt.n_discarded_prompt, - /*.data =*/ std::move(state_data), - /*.checkpoints =*/ prompt.checkpoints, - }; - - return &cur; - } - - bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) { - const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new); - - float f_keep_best = float(lcp_best.second) / prompt.tokens.size(); - float sim_best = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt); - LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt); - - auto it_best = states.end(); - - // find the most similar cached prompt, that would also preserve the most context - for (auto it = states.begin(); it != states.end(); ++it) { - const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new); - const float f_keep_cur = float(lcp_cur.first) / it->tokens.size(); - const float sim_cur = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt); - if (sim_best < sim_cur) { - f_keep_best = f_keep_cur; - sim_best = sim_cur; - it_best = it; - } - } - - if (it_best != states.end()) { - LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, 
it_best->n_kept_prompt, it_best->n_discarded_prompt); - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot); - if (n != size) { - LLAMA_LOG_INFO("failed to restore state with size %zu\n", size); - return false; - } - - it_best->data.clear(); - it_best->data.shrink_to_fit(); - - prompt = std::move(*it_best); - - states.erase(it_best); - } - - return true; - } - - void update() { - if (limit_size > 0) { - // always keep at least one state, regardless of the limits - while (states.size() > 1 && size() > limit_size) { - if (states.empty()) { - break; - } - - LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - // average size per token - const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); - - // dynamically increase the token limit if it can fit in the memory limit - const size_t limit_tokens_cur = limit_size > 0 ? std::max(limit_tokens, limit_size / size_per_token) : limit_tokens; - - //if (limit_tokens > 0) { - // - // while (states.size() > 1 && n_tokens() > limit_tokens_cur) { - // if (states.empty()) { - // break; - // } - - // LLAMA_LOG_INFO(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", - // limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); - - // states.pop_front(); - // } - //} - - LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", - states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); - - for (const auto& state : states) { - LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n", - (const void*)&state, state.n_tokens(), state.n_discarded_prompt, state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); - } - } -}; - - -struct server_slot { - int id; - int id_task = -1; - int id_multi = -1; - - struct slot_params params; - - slot_state state = SLOT_STATE_IDLE; - slot_command command = SLOT_COMMAND_NONE; - - llama_context* ctx = nullptr; - // used to determine the slot that has been used the longest - int64_t t_last_used = -1; - - std::unique_ptr task; - - // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_past = 0; - int32_t n_past_prompt = 0; - int32_t n_decoded = 0; - int32_t n_remaining = -1; - int32_t n_discarded_prompt = 0; - int32_t n_kept_prompt = 0; - - int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - - int32_t n_prompt_tokens = 0; - int32_t n_prompt_tokens_processed = 0; - - json prompt; // can be either a string, array of strings or array of token ids - - // when a task is submitted, we first tokenize the prompt and store it here - server_tokens prompt_tokens; - server_tokens cache_tokens; - - std::string generated_text; - - std::vector generated_token_probs; - common_chat_msg chat_msg; - - bool infill = false; - bool embedding = false; - bool has_next_token = true; - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - - bool oaicompat = false; - - std::string oaicompat_model; - std::string stopping_word; - stop_type stop; - - server_prompt server_cached_prompt; - - void prompt_save(server_prompt_cache & prompt_cache) const { - assert(server_cached_prompt.data.size() == 0); - - const size_t cur_size = 
llama_state_seq_get_size(ctx, id); - - LLAMA_LOG_INFO(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int)server_cached_prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - - auto* cur = prompt_cache.alloc(server_cached_prompt, cur_size); - if (cur == nullptr) { - return; - } - - llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id); - } - - void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) { - bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id); - if (!res) { - LLAMA_LOG_INFO("failed to load prompt from cache\n"); - } - } - - - // sampling - llama_token sampled; - struct llama_sampling_params sparams; - llama_sampling_context * ctx_sampling = nullptr; - json json_schema; - - common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - std::vector generated_tool_call_ids; - - int32_t ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width - - // multimodal - mtmd_context * mctx = nullptr; - - // speculative decoding - struct llama_speculative * spec = nullptr; - llama_context * ctx_dft = nullptr; - llama_batch batch_spec = {}; - - // speculative decoding stats - int32_t n_draft_total = 0; // Total draft tokens generated - int32_t n_draft_accepted = 0; // Draft tokens actually accepted - - int32_t n_past_se = 0; // self-extend - - // stats - size_t n_sent_text = 0; // number of sent text character - size_t n_sent_token_probs = 0; - - int64_t t_start_process_prompt; - int64_t t_start_generation; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - void reset() { - n_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - n_sent_token_probs = 0; - infill = false; - ga_i = 0; - n_past_se = 0; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - - generated_token_probs.clear(); - - - // Reset speculative decoding stats - n_draft_total = 0; - n_draft_accepted = 0; - chat_msg = {}; - json_schema = json(); - generated_tool_call_ids.clear(); - - task.reset(); - } - - bool has_budget(gpt_params &global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) { - return true; // limitless - } - - n_remaining = -1; - - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool available() const { - return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE; - } - - bool is_processing() const { - return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING; - } - - void add_token_string(const completion_token_output & token) { - if (command == SLOT_COMMAND_RELEASE) { - return; - } - generated_token_probs.push_back(token); - } - - void release() { - if (state == SLOT_STATE_PROCESSING) { - t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - command = SLOT_COMMAND_RELEASE; - task.reset(); - } - } - - - json get_formated_timings() const { - return json { - {"prompt_n", n_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed}, - - {"predicted_n", n_decoded}, - {"predicted_ms", 
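// Condensed sketch of the save path in server_slot::prompt_save(): ask llama.cpp how large the
// sequence state is, size a buffer accordingly (in the server this buffer lives inside a
// server_prompt cache entry handed out by alloc(), which may refuse, e.g. when the prompt is
// already cached), then serialize the sequence into it. Assumes a valid llama_context and
// mirrors the llama_state_seq_* calls used in the surrounding code.
#include <cstdint>
#include <cstdio>
#include <vector>
#include "llama.h"

static bool save_sequence_state(llama_context * ctx, llama_seq_id seq_id, std::vector<uint8_t> & out) {
    const size_t cur_size = llama_state_seq_get_size(ctx, seq_id);
    if (cur_size == 0) {
        return false;
    }

    out.resize(cur_size);

    const size_t written = llama_state_seq_get_data(ctx, out.data(), cur_size, seq_id);
    if (written != cur_size) {
        std::printf("failed to serialize sequence state (%zu of %zu bytes)\n", written, cur_size);
        return false;
    }
    return true;
}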
t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - - {"n_ctx", n_ctx}, - {"n_past", n_past}, - }; - } - - result_timings get_timings() const { - result_timings timings; - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; - timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - - timings.n_ctx = n_ctx; - timings.n_past = n_past; - - - // Add speculative metrics - if (n_draft_total > 0) { - timings.draft_n = n_draft_total; - timings.draft_n_accepted = n_draft_accepted; - } - - return timings; - } - - const common_chat_msg& update_chat_msg(std::vector& diffs) { - auto previous_msg = chat_msg; - auto new_msg = common_chat_parse( - generated_text, - /* is_partial= */ stop != STOP_TYPE_EOS, - params.oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); - } - //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", generated_text.c_str()); - //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.reasoning_content.c_str()); - //LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.content.c_str()); - return chat_msg; - } - - - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { - size_t stop_pos = std::string::npos; - - for (const std::string & word : params.antiprompt) { - size_t pos; - - if (is_full_stop) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - pos = string_find_partial_stop(text, word); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (is_full_stop) { - stopped_word = true; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - void print_timings() const { - char buffer[512]; - - double t_token = t_prompt_processing / n_prompt_tokens_processed; - double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", - t_prompt_processing, n_prompt_tokens_processed, - t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"n_prompt_tokens_processed", n_prompt_tokens_processed}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - t_token = t_token_generation / n_decoded; - n_tokens_second = 1e3 / t_token_generation * n_decoded; - - snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", - t_token_generation, n_decoded, - t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_token_generation", t_token_generation}, - {"n_decoded", n_decoded}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"t_token_generation", t_token_generation}, - {"t_total", t_prompt_processing + t_token_generation}, - }); - } -}; - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - void init() { - t_start = ggml_time_us(); - } - - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; - } - - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - -struct server_queue { - int id = 0; - bool running; - - // queues - std::vector queue_tasks; - std::vector queue_tasks_deferred; - - std::vector queue_multitasks; - - std::mutex mutex_tasks; - std::condition_variable condition_tasks; - - // callback functions - std::function callback_new_task; - std::function callback_finish_multitask; - std::function callback_update_slots; - - // Add a new task to the end of the queue - int post(server_task task) { - std::unique_lock lock(mutex_tasks); - if (task.id == -1) { - task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); - } - 
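// Sketch of the full-stop branch of find_stopping_strings(): when a complete token has just
// been appended, each stop word only needs to be searched for in the tail of the text, i.e.
// starting at text.size() - (word.size() + last_token_size), clamped to 0. The earliest match
// across all antiprompts wins.
#include <cstdio>
#include <string>
#include <vector>

static size_t find_full_stop(const std::string & text, size_t last_token_size,
                             const std::vector<std::string> & antiprompt) {
    size_t stop_pos = std::string::npos;
    for (const std::string & word : antiprompt) {
        const size_t tmp      = word.size() + last_token_size;
        const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
        const size_t pos      = text.find(word, from_pos);
        if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
            stop_pos = pos;
        }
    }
    return stop_pos;
}

int main() {
    const std::string text = "The answer is 42.\nUser:";
    const size_t pos = find_full_stop(text, /*last_token_size=*/5, { "User:", "###" });
    std::printf("stop at %zu\n", pos); // expected: the position of "User:"
    return 0;
}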
queue_tasks.push_back(std::move(task)); - condition_tasks.notify_one(); - return task.id; - } - - // Add a new task, but defer until one slot is available - void defer(server_task && task) { - std::unique_lock lock(mutex_tasks); - queue_tasks_deferred.push_back(std::move(task)); - } - - // Get the next id for creating anew task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - LOG_VERBOSE("new task id", {{"new_id", new_id}}); - return new_id; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = std::move(callback); - } - - // Register function to process a multitask when it is finished - void on_finish_multitask(std::function callback) { - callback_finish_multitask = std::move(callback); - } - - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback) { - callback_update_slots = std::move(callback); - } - - // Call when the state of one slot is changed - void notify_slot_changed() { - // move deferred tasks back to main loop - std::unique_lock lock(mutex_tasks); - for (auto & task : queue_tasks_deferred) { - queue_tasks.push_back(std::move(task)); - } - queue_tasks_deferred.clear(); - } - - // end the start_loop routine - void terminate() { - std::unique_lock lock(mutex_tasks); - running = false; - condition_tasks.notify_all(); - } - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. maybe copy data into slot) - * - Check if multitask is finished - * - Update all slots - */ - void start_loop() { - running = true; - - while (true) { - LOG_VERBOSE("new task may arrive", {}); - - while (true) { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - server_task task = std::move(queue_tasks.front()); - queue_tasks.erase(queue_tasks.begin()); - lock.unlock(); - LOG_VERBOSE("callback_new_task", {{"id_task", task.id}}); - callback_new_task(std::move(task)); - } - - LOG_VERBOSE("update_multitasks", {}); - - // check if we have any finished multitasks - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) { - if (queue_iterator->subtasks_remaining.empty()) { - // all subtasks done == multitask is done - server_task_multi current_multitask = *queue_iterator; - callback_finish_multitask(current_multitask); - // remove this multitask - queue_iterator = queue_multitasks.erase(queue_iterator); - } else { - ++queue_iterator; - } - } - - // all tasks in the current loop is processed, slots data is now ready - LOG_VERBOSE("callback_update_slots", {}); - - callback_update_slots(); - - LOG_VERBOSE("wait for new task", {}); - { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { - if (!running) { - LOG_VERBOSE("ending start_loop", {}); - return; - } - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } - } - - // - // functions to manage multitasks - // - - // add a multitask by specifying the id of all subtask (subtask is a server_task) - void add_multitask(int id_multi, std::vector & sub_ids) { - std::lock_guard lock(mutex_tasks); - server_task_multi multi; - multi.id = id_multi; - std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - } - - // updatethe remaining subtasks, while appending results to multitask - void update_multitask(int 
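// Stripped-down version of the server_queue pattern: post() enqueues under a mutex and
// notifies the loop; the loop drains all pending tasks, then blocks on the condition variable
// until either a new task arrives or terminate() clears the running flag. Task payloads are
// plain ints here instead of server_task objects.
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <mutex>
#include <thread>

struct tiny_queue {
    std::deque<int>         tasks;
    std::mutex              mtx;
    std::condition_variable cv;
    bool                    running = true;

    void post(int task) {
        { std::lock_guard<std::mutex> lock(mtx); tasks.push_back(task); }
        cv.notify_one();
    }
    void terminate() {
        { std::lock_guard<std::mutex> lock(mtx); running = false; }
        cv.notify_all();
    }
    void loop() {
        while (true) {
            std::unique_lock<std::mutex> lock(mtx);
            cv.wait(lock, [&] { return !tasks.empty() || !running; });
            while (!tasks.empty()) {
                const int task = tasks.front();
                tasks.pop_front();
                lock.unlock();
                std::printf("processing task %d\n", task); // callback_new_task() in the server
                lock.lock();
            }
            if (!running) {
                return;
            }
        }
    }
};

int main() {
    tiny_queue q;
    std::thread worker([&] { q.loop(); });
    q.post(1);
    q.post(2);
    q.terminate();
    worker.join();
    return 0;
}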
id_multi, int id_sub, server_task_result & result) { - std::lock_guard lock(mutex_tasks); - for (auto & multitask : queue_multitasks) { - if (multitask.id == id_multi) { - multitask.subtasks_remaining.erase(id_sub); - multitask.results.push_back(result); - } - } - } -}; - -struct server_response { - typedef std::function callback_multitask_t; - callback_multitask_t callback_update_multitask; - - // for keeping track of all tasks waiting for the result - std::set waiting_task_ids; - - // the main result queue - std::vector queue_results; - - std::mutex mutex_results; - std::condition_variable condition_results; - - // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task) { - LOG_VERBOSE("waiting for task id", {{"id_task", id_task}}); - - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(id_task); - } - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task) { - LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); - - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); - } - - // This function blocks the thread until there is a response for this id_task - server_task_result recv(int id_task) { - while (true) { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - return !queue_results.empty(); - }); - - for (int i = 0; i < (int) queue_results.size(); i++) { - if (queue_results[i].id == id_task) { - assert(queue_results[i].id_multi == -1); - server_task_result res = queue_results[i]; - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // Register the function to update multitask - void on_multitask_update(callback_multitask_t callback) { - callback_update_multitask = std::move(callback); - } - - // Send a new result to a waiting id_task - void send(server_task_result result) { - LOG_VERBOSE("send new result", {{"id_task", result.id}}); - - std::unique_lock lock(mutex_results); - for (const auto & id_task : waiting_task_ids) { - // LOG_TEE("waiting task id %i \n", id_task); - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (result.id_multi == id_task) { - LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}}); - callback_update_multitask(id_task, result.id, result); - continue; - } - - if (result.id == id_task) { - LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}}); - queue_results.push_back(result); - condition_results.notify_all(); - return; - } - } - } -}; - -struct server_context { - llama_model * model = nullptr; - llama_context * ctx = nullptr; - std::vector lora_adapters; - - gpt_params params; - - llama_batch batch; - - bool clean_kv_cache = true; - bool add_bos_token = true; - bool has_eos_token = false; - - // multimodal - mtmd_context * mctx = nullptr; - - // For speculative decoding - llama_model * model_draft = nullptr; - llama_context * ctx_draft = nullptr; - llama_context_params cparams_dft; - - int32_t n_ctx; // total context for all clients / slots - - // system prompt - bool system_need_update = false; - - std::string system_prompt; - std::vector system_tokens; - - // slots / clients - std::vector slots; - json default_generation_settings_for_props; - - server_queue queue_tasks; - server_response queue_results; - - std::unique_ptr prompt_cache; - - server_metrics metrics; - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options 
oai_parser_opt; - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - int32_t cache_ram_n_min = 0; - float cache_ram_similarity = 0.5f; - - ~server_context() { - if (ctx) { - llama_free(ctx); - ctx = nullptr; - } - - if (model) { - llama_free_model(model); - model = nullptr; - } - // Free multimodal - mtmd_free(mctx); - // Free draft model and context if they exist - if (ctx_draft) { - llama_free(ctx_draft); - ctx_draft = nullptr; - } - if (model_draft) { - llama_free_model(model_draft); - model_draft = nullptr; - } - - // Clear any sampling context - for (server_slot & slot : slots) { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); - } - if (slot.ctx_dft) { - llama_free(slot.ctx_dft); - } - if (slot.spec) { - llama_speculative_free(slot.spec); - } - llama_batch_free(slot.batch_spec); - } - - llama_batch_free(batch); - } - - bool load_model(const gpt_params & params_) { - params = params_; - - llama_init_result llama_init = llama_init_from_gpt_params(params); - - model = llama_init.model; - ctx = llama_init.context; - lora_adapters = llama_init.lora_adapters; - - if (model == nullptr) { - LOG_ERROR("unable to load model", {{"model", params.model}}); - return false; - } - - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_should_add_bos_token(model); - has_eos_token = llama_add_eos_token(model) != 1; - - chat_templates = common_chat_templates_init(model, params.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja, {}); - } - catch (const std::exception& e) { - LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - - bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty(); - std::string & mmproj_path = params.mmproj.path; - if (!mmproj_path.empty()) { - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params.mmproj_use_gpu; - mparams.print_timings = false; - mparams.n_threads = params.n_threads; - mparams.flash_attn_type = params.flash_attn? LLAMA_FLASH_ATTN_TYPE_ENABLED: LLAMA_FLASH_ATTN_TYPE_DISABLED; - mparams.verbosity = params.verbosity > 0 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; - mparams.image_min_tokens = params.image_min_tokens; - mparams.image_max_tokens = params.image_max_tokens; - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); - if (mctx == nullptr) { - LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); - return false; - } - LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - - if (params.ctx_shift) { - params.ctx_shift = false; - LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - - //if (params.n_cache_reuse) { - // params_base.n_cache_reuse = 0; - // SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); - //} - - if (has_draft_model) { - LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } - } - // Load draft model for speculative decoding if specified - if (has_draft_model) { - LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); - - gpt_params params_dft; - params_dft.devices = params.devices_draft; - params_dft.model = params.model_draft; - params_dft.n_gpu_layers = params.n_gpu_layers_draft; - params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft; - params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft; - params_dft.flash_attn = params.flash_attn; - if (!params.draft_params.empty()) { - auto [argc, argv] = parse_command_line("llama-server "+params.draft_params); - if (!gpt_params_parse(argc, argv, params_dft)) { - gpt_params_print_usage(argc, argv, params_dft); - free_command_line(argc, argv); - return false; - }; - free_command_line(argc, argv); - } - LOG_INFO("", { {"model", params_dft.model} }); - if (params_dft.n_ctx == 0) { - params_dft.n_ctx = params.n_ctx_draft; - } - params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx; - params_dft.n_parallel = 1; - - llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); - - llama_model * model_dft = llama_init_dft.model; - if (model_dft == nullptr) { - LOG_ERROR("failed to load draft model", {{"model", params.model_draft}}); - return false; - } - - if (!llama_speculative_are_compatible(ctx, llama_init_dft.context)) { - LOG_INFO("the draft model is not compatible with the target model. 
tokens will be translated between the draft and target models.", {{}}); - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context); - - cparams_dft = llama_context_params_from_gpt_params(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - model_draft = llama_init_dft.model; - ctx_draft = llama_init_dft.context; - } - return true; - } - - - void init() { - const int32_t n_ctx_slot = n_ctx / params.n_parallel; - - LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}}); - - for (int i = 0; i < params.n_parallel; i++) { - server_slot slot; - - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - slot.n_predict = params.n_predict; - slot.mctx = mctx; - slot.cache_tokens.has_mtmd = mctx != nullptr; - - LOG_INFO("new slot", { - {"id_slot", slot.id}, - {"n_ctx_slot", slot.n_ctx} - }); - - const int ga_n = params.grp_attn_n; - const int ga_w = params.grp_attn_w; - - if (ga_n != 1) { - GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT - GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT - //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT - //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT - - LOG_INFO("slot self-extend", { - {"id_slot", slot.id}, - {"ga_n", ga_n}, - {"ga_w", ga_w} - }); - } - - slot.ga_i = 0; - slot.ga_n = ga_n; - slot.ga_w = ga_w; - - slot.sparams = params.sparams; - - // Initialize speculative decoding if a draft model is loaded - if (ctx_draft) { - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); - // slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); // initialized twice - slot.ctx_dft = ctx_draft; - if (slot.ctx_dft == nullptr) { - LOG_ERROR("failed to create draft context", {}); - return; - } - - slot.spec = llama_speculative_init(ctx, slot.ctx_dft); - if (slot.spec == nullptr) { - LOG_ERROR("failed to create speculator", {}); - return; - } - for (auto & pair : params.replacements_draft) { - llama_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); - } - - } - - slot.reset(); - - slots.push_back(std::move(slot)); - } - - default_generation_settings_for_props = get_formated_generation(slots.front()); - default_generation_settings_for_props["seed"] = -1; - - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1); - } - - metrics.init(); - - if (params.cache_ram_mib != 0) { - if (params.cache_ram_mib < 0) { - LLAMA_LOG_INFO("prompt cache is enabled, size limit: %s\n", "no limit"); - } - else { - LLAMA_LOG_INFO("prompt cache is enabled, size limit: %d MiB\n", params.cache_ram_mib); - } - LLAMA_LOG_INFO("%s", "use `--cache-ram 0` to disable the prompt cache\n"); - // only apply ram size limit. No token limit for now. - prompt_cache = std::make_unique(ctx,params.cache_ram_mib, 0); - } - else { - LLAMA_LOG_INFO("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); - } - - // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) - // 2. 
The chat template supports it - const bool enable_thinking = params.use_jinja && params.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - //LLAMA_LOG_INFO("Enable thinking? %d\n", enable_thinking); - - oai_parser_opt = { - /* use_jinja */ params.use_jinja, - /* prefill_assistant */ params.prefill_assistant, - /* reasoning_format */ params.reasoning_format, - /* chat_template_kwargs */ params.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? mtmd_support_audio(mctx) : false, - /* enable_thinking */ enable_thinking, - }; - } - - std::vector tokenize(const json & json_prompt, bool add_special) const { - // TODO: currently, we tokenize using special tokens by default - // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) - // but it's better compared to completely ignoring ChatML and other chat templates - const bool TMP_FORCE_SPECIAL = true; - - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - std::vector prompt_tokens; - - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - - std::vector p; - if (first) { - p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - first = false; - } else { - p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - } - - return prompt_tokens; - } - - server_slot * get_slot_by_id(int id) { - for (server_slot & slot : slots) { - if (slot.id == id) { - return &slot; - } - } - - return nullptr; - } - - server_slot * get_available_slot(const server_task & task) { - server_slot * ret = nullptr; - bool update_cache = false; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { - int max_lcp_len = 0; - float sim_best = 0; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (!slot.available()) { - continue; - } - const auto & cache_tokens = slot.cache_tokens; - // skip the slot if it does not contains prompt - if (cache_tokens.empty()) { - continue; - } - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - auto lcp_len = cache_tokens.get_common_prefix(slot.ctx,task.tokens); - // fraction of the Longest Common Prefix length with respect to the input prompt and cached prompt length - float sim_cur = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, 0, 0); - // handle context shift - if (slot.ga_n == 1 && slot.n_discarded_prompt > 0 && task.tokens.size()>=slot.n_ctx) { - float sim_cur_ctx_shift = cache_tokens.get_tokens_similarity(slot.ctx, task.tokens, slot.n_kept_prompt, slot.n_discarded_prompt); - if (sim_cur_ctx_shift > sim_cur) { - sim_cur = sim_cur_ctx_shift; - } - } - - // select the current slot if the criteria match - if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { - sim_best = sim_cur; - max_lcp_len = lcp_len.first; - ret = &slot; - } - } - if (ret != nullptr) { - LOG_VERBOSE("selected slot by lcp 
similarity", { - {"id_slot", ret->id}, - {"max_lcp_len", max_lcp_len}, - {"similarity", sim_best}, - }); - } - } - - // find the slot that has been least recently used - if (ret == nullptr) { - int64_t t_last = ggml_time_us(); - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (!slot.available()) { - continue; - } - // select the current slot if the criteria match - if (slot.t_last_used < t_last) { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) { - LOG_VERBOSE("selected slot by lru", { - {"id_slot", ret->id}, - {"t_last", t_last}, - }); - } - } - if (ret) { - const auto& tokens = ret->cache_tokens; - float f_keep = 0.0f; - if (!tokens.empty()) { - if (ret->ga_n == 1 && ret->n_discarded_prompt > 0 && task.tokens.size() >= ret->n_ctx) { - f_keep = tokens.get_cached_tokens_similarity(ret->ctx, task.tokens, ret->params.n_keep + add_bos_token, ret->n_discarded_prompt); - } - else { - f_keep = tokens.get_cached_tokens_similarity(ret->ctx,task.tokens, 0, 0); - } - // if we are about to lose a large portion of the existing context - save it in the prompt cache - if (f_keep < cache_ram_similarity) { - update_cache = true; - } - } - update_cache = update_cache && prompt_cache; - // cache prompts only for completion tasks - update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; - - // don't update the cache if the slot's context is above cache_ram_n_min - update_cache = update_cache && tokens.size() >= cache_ram_n_min; - - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - - LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", - (int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity); - if (update_cache) { - const int64_t t_start = ggml_time_us(); - LLAMA_LOG_INFO("updating prompt cache\n"); - ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens - ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt; - ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt; - - ret->prompt_save(*prompt_cache); - LLAMA_LOG_INFO("prompt cache save took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); - } - // has prompts saved earlier to load - if (prompt_cache && !prompt_cache->states.empty()) { - const int64_t t_start = ggml_time_us(); - ret->server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens - ret->server_cached_prompt.n_discarded_prompt = ret->n_discarded_prompt; - ret->server_cached_prompt.n_kept_prompt = ret->n_kept_prompt; - - ret->prompt_load(*prompt_cache, task.tokens); - prompt_cache->update(); - - ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens - ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt; - ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt; - - LLAMA_LOG_INFO("prompt cache load took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); - } - } - return ret; - } - - bool launch_slot_with_task(server_slot & slot, server_task & task) { - slot_params default_params; - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - llama_sampling_params default_sparams = params.sparams; - auto & data = task.data; - - if 
(data.count("__oaicompat") != 0) { - slot.oaicompat = true; - slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - } else { - slot.oaicompat = false; - slot.oaicompat_model = ""; - } - slot.params.timings_per_token = json_value(data, "timings_per_token", false); - slot.params.stream = json_value(data, "stream", false); - auto stream_opt = json_value(data, "stream_options", json::object()); - slot.params.include_usage = json_value(stream_opt, "include_usage", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", true); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); - slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability); - slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold); - slot.sparams.top_n_sigma = json_value(data, "top_n_sigma", default_sparams.top_n_sigma); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier); - slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base); - slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length); - slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - - slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs); - - // speculative decoding parameters - slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft); - slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min); - 
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min); - - // Clamp speculative parameters - slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); - slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0); - slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0); - - if (slot.sparams.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (slot.sparams.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (slot.sparams.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - slot.sparams.penalty_last_n = llama_n_ctx(ctx); - } - - if (slot.sparams.dry_penalty_last_n == -1) { - slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx); - - } - if (slot.sparams.dry_base < 1.0f) - { - slot.sparams.dry_base = default_sparams.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (slot.sparams.dry_sequence_breakers.empty()) { - send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - LLAMA_LOG_DEBUG("JSON schema: %s\n", schema.dump(2).c_str()); - slot.sparams.grammar = json_schema_to_grammar(schema); - LLAMA_LOG_DEBUG("Converted grammar: %s\n", slot.sparams.grammar.c_str()); - } - catch (const std::exception& e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } - else { - slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - LLAMA_LOG_DEBUG("Grammar: %s\n", slot.sparams.grammar.c_str()); - slot.sparams.grammar_lazy = json_value(data, "grammar_lazy", default_sparams.grammar_lazy); - LLAMA_LOG_DEBUG("Grammar lazy: %s\n", slot.sparams.grammar_lazy ? "true" : "false"); - } - - if (slot.params.cache_prompt && slot.ga_n != 1) { - LOG_WARNING("cache_prompt is not supported with group-attention", {}); - slot.params.cache_prompt = false; - } - - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { - // Might be better to reject the request with a 400 ? - LOG_WARNING("Max tokens to predict exceeds server configuration", { - {"params.n_predict", slot.params.n_predict}, - {"slot.n_predict", slot.n_predict}, - }); - slot.params.n_predict = slot.n_predict; - } - - // infill - slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); - slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); - - // get prompt - if (!task.infill) { - // maybe not needed since prompt has been tokenized? 
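// Illustrative sketch, not part of this patch: the precedence implied by the
// "json_schema" / "grammar" handling above — an explicit grammar wins, a JSON
// schema is otherwise converted to a grammar, and conversion failures are
// surfaced as request errors. request_fields and the to_grammar() stub are
// hypothetical stand-ins for the JSON request and the real schema converter.
#include <optional>
#include <stdexcept>
#include <string>

struct request_fields {
    std::optional<std::string> grammar;     // raw GBNF, if the client sent one
    std::optional<std::string> json_schema; // JSON schema, if the client sent one
};

// hypothetical stand-in for the real schema-to-grammar conversion
static std::string to_grammar(const std::string & /*schema*/) {
    return "root ::= value"; // placeholder output only
}

static std::string resolve_grammar(const request_fields & req) {
    if (req.grammar) {
        return *req.grammar;                 // explicit grammar takes priority
    }
    if (req.json_schema) {
        try {
            return to_grammar(*req.json_schema);
        } catch (const std::exception & e) {
            throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
        }
    }
    return "";                               // no constraint requested
}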
- const auto & prompt = data.find("prompt"); - if (!slot.prompt_tokens.validate(ctx)) { - send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); - return false; - } - if (prompt == data.end()) { - send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - if ((prompt->is_string()) || - (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) || - (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) { - slot.prompt = *prompt; - } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { - slot.prompt = prompt->at(0); - } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); - return false; - } - slot.prompt_tokens = std::move(task.tokens); - } - - // penalize user-provided tokens - { - slot.sparams.penalty_prompt_tokens.clear(); - slot.sparams.use_penalty_prompt_tokens = false; - - const auto & penalty_prompt = data.find("penalty_prompt"); - - if (penalty_prompt != data.end()) { - if (penalty_prompt->is_string()) { - const auto penalty_prompt_string = penalty_prompt->get(); - slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); - - if (slot.params.n_predict > 0) { - slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict); - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - else if (penalty_prompt->is_array()) { - const auto n_tokens = penalty_prompt->size(); - slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); - - const int n_vocab = llama_n_vocab(model); - for (const auto & penalty_token : *penalty_prompt) { - if (penalty_token.is_number_integer()) { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) { - slot.sparams.penalty_prompt_tokens.push_back(tok); - } - } - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - } - } - { - auto it = data.find("chat_format"); - if (it != data.end()) { - slot.params.oaicompat_chat_syntax.format = static_cast(it->get()); - LLAMA_LOG_DEBUG("Chat format: %s\n", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)); - } - else { - slot.params.oaicompat_chat_syntax.format = default_params.oaicompat_chat_syntax.format; - } - common_reasoning_format reasoning_format = params.reasoning_format; - if (data.contains("reasoning_format")) { - reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); - } - slot.params.oaicompat_chat_syntax.reasoning_format = reasoning_format; - slot.params.oaicompat_chat_syntax.reasoning_in_content = slot.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - slot.params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); - - slot.params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - } - { - - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto& t : *preserved_tokens) { - auto ids = llama_tokenize(model, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - LOG("Preserved token: %d\n", ids[0]); 
- slot.sparams.preserved_tokens.insert(ids[0]); - } - else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - LOG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto& t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto& word = ct.value.value; - auto ids = llama_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - LOG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - slot.sparams.grammar_triggers.push_back(std::move(trigger)); - } - else { - LOG("Grammar trigger word: `%s`\n", word.c_str()); - slot.sparams.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word }); - } - } - else { - //slot.sparams.grammar_triggers.push_back(ct); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - LLAMA_LOG_DEBUG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } - else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - LLAMA_LOG_DEBUG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } - else { - throw std::runtime_error("Unknown grammar trigger type"); - } - slot.sparams.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - - if (slot.sparams.grammar_lazy && slot.sparams.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - { - slot.sparams.logit_bias.clear(); - - if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - slot.sparams.logit_bias[tok] = bias; - } - } else if (el[0].is_string()) { - auto toks = llama_tokenize(model, el[0].get(), false); - for (auto tok : toks) { - slot.sparams.logit_bias[tok] = bias; - } - } - } - } - } - } - - { - slot.params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - slot.params.antiprompt.push_back(word); - } - } - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - slot.sparams.samplers_sequence = llama_sampling_types_from_names(*samplers, false); - } - else if (samplers->is_string()) { - slot.sparams.samplers_sequence = 
llama_sampling_types_from_chars(samplers->get()); - } - else { - slot.sparams.samplers_sequence = default_sparams.samplers_sequence; - } - } - } - - { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); - } - slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model),slot.sparams); - if (slot.ctx_sampling == nullptr) { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - - slot.command = SLOT_COMMAND_LOAD_PROMPT; - // slot.prompt_tokens.clear(); - - LOG_INFO("slot is processing task", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - }); - - return true; - } - - void kv_cache_clear() { - LOG_VERBOSE("clearing KV cache", {}); - - // clear the entire KV cache - llama_kv_cache_clear(ctx); - clean_kv_cache = false; - } - - void system_prompt_update() { - LOG_VERBOSE("system prompt update", { - {"system_prompt", system_prompt}, - }); - - kv_cache_clear(); - system_tokens.clear(); - - if (!system_prompt.empty()) { - system_tokens = ::llama_tokenize(ctx, system_prompt, true); - - const int32_t n_batch = llama_n_batch(ctx); - const int32_t n_tokens_prompt = system_tokens.size(); - - for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i); - - llama_batch_clear(batch); - - for (int32_t j = 0; j < n_tokens; ++j) { - llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); - } - - if (llama_decode(ctx, batch) != 0) { - LOG_ERROR("llama_decode() failed", {}); - return; - } - } - - // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i <= params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); - } - } - - system_need_update = false; - } - - bool system_prompt_set(const std::string & sys_prompt) { - system_prompt = sys_prompt; - - LOG_VERBOSE("system prompt process", { - {"system_prompt", system_prompt}, - }); - - // release all slots - for (server_slot & slot : slots) { - slot.release(); - } - - system_need_update = true; - return true; - } - - bool process_token(completion_token_output & result, server_slot & slot) { - // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = result.text_to_send; - slot.sampled = result.tok; - - // search stop word and delete it - slot.generated_text += token_str; - slot.has_next_token = true; - - if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { - // we can change penalty_prompt_tokens because it is always created from scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); - } - - // check if there is incomplete UTF-8 character at the end - bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - - if (!incomplete) { - size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - - const std::string str_test = slot.generated_text.substr(pos); - bool send_text = true; - - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); - if (stop_pos != std::string::npos) { - slot.generated_text.erase( - slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } - else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) { - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = stop_pos 
== std::string::npos; - } - - // check if there is any token to predict - if (send_text) { - // no send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.n_sent_text += result.text_to_send.size(); - // add the token to slot queue and cache - } else { - result.text_to_send = ""; - } - - slot.add_token_string(result); - if (slot.params.stream) { - send_partial_response(slot, result); - } - } - - if (incomplete) { - slot.has_next_token = true; - } - - // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) { - slot.stopped_limit = true; - slot.has_next_token = false; - - LOG_VERBOSE("stopped by limit", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_decoded", slot.n_decoded}, - {"n_predict", slot.params.n_predict}, - }); - } - - if (llama_token_is_eog(model, result.tok)) { - slot.stopped_eos = true; - slot.has_next_token = false; - - LOG_VERBOSE("eos token found", {}); - } - - auto n_ctx_train = llama_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 - && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { - LOG_WARNING("n_predict is not set and self-context extend is disabled." - " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", { - { "id_slot", slot.id }, - { "params.n_predict", slot.params.n_predict }, - { "slot.n_prompt_tokens", slot.n_prompt_tokens }, - { "slot.n_decoded", slot.n_decoded }, - { "slot.n_predict", slot.n_predict }, - { "n_slots", params.n_parallel }, - { "slot.n_ctx", slot.n_ctx }, - { "n_ctx", n_ctx }, - { "n_ctx_train", n_ctx_train }, - { "ga_n", slot.ga_n }, - }); - slot.truncated = true; - slot.stopped_limit = true; - slot.has_next_token = false; // stop prediction - } - - LOG_VERBOSE("next token", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); - - return slot.has_next_token; // continue - } - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { - size_t n_probs = slot.sparams.n_probs; - size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); - - if (post_sampling) { - const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling); - const size_t max_probs = cur_p->size; - - // set probability for sampled token - for (size_t i = 0; i < max_probs; i++) { - if (cur_p->data[i].id == result.tok) { - result.prob = cur_p->data[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { - result.probs.push_back({ - cur_p->data[i].id, - llama_detokenize(ctx, {cur_p->data[i].id}, special), - cur_p->data[i].p - }); - } - } else { - auto&&[sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs); - - // set probability for sampled token - result.prob = sampled_token_p; - - // set probability for top n_probs tokens - result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { - result.probs.push_back({ - cur[i].id, - llama_detokenize(ctx, {cur[i].id}, special), - cur[i].p - 
}); - } - } - } - - json get_formated_generation(const server_slot & slot) const { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - - std::vector samplers_sequence; - samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); - for (const auto & sampler_type : slot.sparams.samplers_sequence) { - samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); - } - - auto grammar_triggers = json::array(); - for (const auto& trigger : slot.sparams.grammar_triggers) { - grammar_triggers.push_back(trigger.to_json()); - } - - return json { - {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, // Server configured n_predict - {"model", params.model_alias}, - {"seed", slot.sparams.seed}, - {"temperature", slot.sparams.temp}, - {"dynatemp_range", slot.sparams.dynatemp_range}, - {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, - {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, - {"dry_multiplier", slot.sparams.dry_multiplier}, - {"dry_base", slot.sparams.dry_base}, - {"dry_allowed_length", slot.sparams.dry_allowed_length}, - {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n}, - {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"max_tokens", slot.params.n_predict}, // User configured n_predict - {"n_keep", slot.params.n_keep}, - {"n_discard", slot.params.n_discard}, - {"ignore_eos", ignore_eos}, - {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"min_keep", slot.sparams.min_keep}, - {"grammar", slot.sparams.grammar}, - {"grammar_triggers", grammar_triggers}, - {"preserved_tokens", slot.sparams.preserved_tokens}, - {"chat_format", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(slot.params.oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", slot.params.oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", slot.params.oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers_sequence} - }; - } - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, task.id_multi, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, slot.id_multi, error, type); - } - - void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - LOG_ERROR("task error", { - {"id_multi", id_multi}, - {"id_task", id_task}, - {"error", error}, - }); - - server_task_result res; - res.id = id_task; - res.id_multi = 
id_multi; - res.stop = false; - res.error = true; - res.data = format_error_response(error, type); - - queue_results.send(res); - } - - // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task) { - if (mctx) { - int id_multi = 0; - send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); - return false; - } - return true; - } - - void send_partial_response(server_slot & slot, completion_token_output tkn) { - server_task_result res; - res.final_result = false; - res.id = slot.id_task; - res.id_multi = slot.id_multi; - res.error = false; - res.stop = false; - res.stream = slot.params.stream; - res.content = tkn.text_to_send; - res.post_sampling_probs = slot.params.post_sampling_probs; - res.oaicompat = slot.params.oaicompat; - res.oaicompat_model = slot.params.oaicompat_model; - res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res.n_decoded = slot.n_decoded; - res.n_prompt_tokens = slot.n_prompt_tokens; - res.data = json { - {"content", tkn.text_to_send}, - {"stop", false}, - {"id_slot", slot.id}, - {"multimodal", false} - }; - slot.update_chat_msg(res.oaicompat_msg_diffs); - - // populate res.probs_output - if (slot.sparams.n_probs > 0) { - res.probs_output = {tkn}; // copy the token probs - res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output); - } - - if (slot.oaicompat) { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; - } - - // populate timings if this is final response or timings_per_token is enabled - if (slot.params.timings_per_token) { - res.timings = slot.get_timings(); - } - queue_results.send(std::move(res)); - } - - void send_final_response(server_slot& slot) { - server_task_result res; - res.final_result = true; - res.id = slot.id_task; - res.id_multi = slot.id_multi; - res.error = false; - res.stop = true; // to do: set value - res.stream = slot.params.stream; - res.include_usage = slot.params.include_usage; - res.content = slot.generated_text; - res.timings = slot.get_timings(); - res.post_sampling_probs = slot.params.post_sampling_probs; - res.oaicompat = slot.params.oaicompat; - res.oaicompat_model = slot.params.oaicompat_model; - res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs); - res.n_decoded = slot.n_decoded; - res.n_prompt_tokens = slot.n_prompt_tokens; - res.oaicompat_model = slot.oaicompat_model; - res.data = json { - {"content", !slot.params.stream ? 
slot.generated_text : ""}, - {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic - {"id_slot", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}, - //{"oaicompat_chat_format", slot.params.oaicompat_chat_format}, - }; - - // populate res.probs_output - if (slot.sparams.n_probs > 0) { - res.probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output); - } - - if (slot.oaicompat) { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; - } - - queue_results.send(std::move(res)); - } - - void send_embedding(const server_slot & slot, const llama_batch & batch) { - server_task_result res; - res.id = slot.id_task; - res.id_multi = slot.id_multi; - res.error = false; - res.stop = true; - - const int n_embd = llama_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - LOG_ERROR("failed to get embeddings", { - {"token", batch.token [i]}, - {"seq_id", batch.seq_id[i][0]} - }); - - res.data = json { - {"embedding", std::vector(n_embd, 0.0f)}, - {"tokens_evaluated", slot.n_prompt_tokens}, - }; - - continue; - } - - llama_embd_normalize(embd, embd_res.data(), n_embd); - - res.data = json { - {"embedding", embd_res}, - {"tokens_evaluated", slot.n_prompt_tokens}, - }; - } - - queue_results.send(res); - } - - void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens && inputs) { - server_task task; - task.id = id_task; - task.id_multi = id_multi; - task.id_target = 0; - task.data = std::move(data); - task.infill = infill; - task.embedding = embedding; - task.type = SERVER_TASK_TYPE_COMPLETION; - task.tokens = std::move(inputs); - // when a completion task's prompt array is not a singleton, we split it into multiple requests - // otherwise, it's a single-prompt task, we actually queue it - // if there's numbers in the prompt array it will be treated as an array of tokens - if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { - bool numbers = false; - for (const auto & e : task.data.at("prompt")) { - if (e.is_number()) { - numbers = true; - break; - } - } - - // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, - // it will completely stall the server. I don't know where the bug for this is. - // - // if there are numbers, it needs to be treated like a single prompt, - // queue_tasks handles a mix of strings and numbers just fine. 
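// Illustrative sketch, not part of this patch: the decision made below when a
// completion request carries a "prompt" array — if any element is a number the
// array is treated as one pre-tokenized prompt, otherwise each string becomes
// its own subtask under a shared multitask id. The variant-based prompt_array
// is a hypothetical stand-in for the JSON array in the request.
#include <string>
#include <variant>
#include <vector>

using prompt_elem  = std::variant<int, std::string>;
using prompt_array = std::vector<prompt_elem>;

enum class prompt_dispatch {
    single_task,    // queue as one task (single prompt, or an array of token ids)
    split_subtasks  // fan out into one subtask per string prompt
};

static prompt_dispatch classify_prompt(const prompt_array & prompt) {
    if (prompt.size() <= 1) {
        return prompt_dispatch::single_task;
    }
    for (const auto & e : prompt) {
        if (std::holds_alternative<int>(e)) {
            return prompt_dispatch::single_task; // token ids: keep as one prompt
        }
    }
    return prompt_dispatch::split_subtasks;
}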
- if (numbers) { - queue_tasks.post(std::move(task)); - } else { - split_multiprompt_task(id_task, task); - } - } else { - queue_tasks.post(std::move(task)); - } - } - - void request_cancel(int id_task) { - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; - task.id_target = id_task; - - queue_tasks.post(std::move(task)); - } - - void split_multiprompt_task(int id_multi, server_task & multiprompt_task) { - const int prompt_count = multiprompt_task.data.at("prompt").size(); - if (prompt_count <= 1) { - send_error(multiprompt_task, "error while handling multiple prompts"); - return; - } - - // generate all the ID for subtask - std::vector subtask_ids(prompt_count); - for (int i = 0; i < prompt_count; i++) { - subtask_ids[i] = queue_tasks.get_new_id(); - } - - // queue up the multitask so we can track its subtask progression - queue_tasks.add_multitask(id_multi, subtask_ids); - - // add subtasks - for (int i = 0; i < prompt_count; i++) { - json subtask_data = multiprompt_task.data; - subtask_data["prompt"] = subtask_data.at("prompt")[i]; - - // subtasks inherit everything else (infill mode, embedding mode, etc.) - request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding, - std::move(multiprompt_task.tokens)); - } - } - - void process_single_task(server_task && task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - { - const int id_slot = json_value(task.data, "id_slot", -1); - - server_slot * slot; - - if (id_slot != -1) { - slot = get_slot_by_id(id_slot); - } else { - slot = get_available_slot(task); - } - - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); - queue_tasks.defer(std::move(task)); - break; - } - if (!slot->available()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(std::move(task)); - break; - } - - if (task.data.contains("system_prompt")) { - std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); - system_prompt_set(sys_prompt); - - for (server_slot & slot : slots) { - slot.n_past = 0; - slot.n_past_se = 0; - } - } - - slot->reset(); - - slot->id_task = task.id; - slot->id_multi = task.id_multi; - slot->infill = task.infill; - slot->embedding = task.embedding; - - if (!launch_slot_with_task(*slot, task)) { - LOG_ERROR("error while launching slot", task.data); - break; - } - } break; - case SERVER_TASK_TYPE_CANCEL: - { - // release slot linked with the task id - for (auto & slot : slots) { - if (slot.id_task == task.id_target) { - slot.release(); - break; - } - } - } break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: - { - // do nothing - } break; - case SERVER_TASK_TYPE_METRICS: - { - json slots_data = json::array(); - - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot & slot : slots) { - json slot_data = get_formated_generation(slot); - slot_data["id"] = slot.id; - slot_data["id_task"] = slot.id_task; - slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; - slot_data["next_token"] = { - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }; - - if (slot_data["state"] == SLOT_STATE_IDLE) { - 
n_idle_slots++; - } else { - n_processing_slots++; - } - - slots_data.push_back(slot_data); - } - LOG_INFO("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots} - }); - - LOG_VERBOSE("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots}, - {"slots", slots_data} - }); - - server_task_result res; - res.id = task.id; - res.id_multi = task.id_multi; - res.stop = true; - res.error = false; - res.data = { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", queue_tasks.queue_tasks_deferred.size() }, - { "t_start", metrics.t_start}, - - { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, - { "t_tokens_generation_total", metrics.t_tokens_generation_total}, - { "n_tokens_predicted_total", metrics.n_tokens_predicted_total}, - { "t_prompt_processing_total", metrics.t_prompt_processing_total}, - - { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, - { "t_prompt_processing", metrics.t_prompt_processing}, - { "n_tokens_predicted", metrics.n_tokens_predicted}, - { "t_tokens_generation", metrics.t_tokens_generation}, - - { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, - { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, - - { "slots", slots_data }, - }; - - if (json_value(task.data, "reset_bucket", false)) { - metrics.reset_bucket(); - } - queue_results.send(res); - } break; - case SERVER_TASK_TYPE_SLOT_SAVE: - { - if (!ensure_no_mtmd(task.id)) { - break; - } - int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (!slot->available()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(std::move(task)); - break; - } - - const size_t token_count = slot->cache_tokens.size(); - const int64_t t_start = ggml_time_us(); - - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); - - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); - - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; - - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", token_count }, // tokens saved - { "n_written", nwrite }, // bytes written - { "timings", { - { "save_ms", t_save_ms } - } } - }; - queue_results.send(result); - } break; - case SERVER_TASK_TYPE_SLOT_RESTORE: - { - if (!ensure_no_mtmd(task.id)) break; - int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (!slot->available()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(std::move(task)); - break; - } - - const int64_t t_start = ggml_time_us(); - - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); - - slot->cache_tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = 
llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); - if (nread == 0) { - slot->cache_tokens.resize(0); - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); - break; - } - slot->cache_tokens.resize(token_count); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", token_count }, // tokens restored - { "n_read", nread }, // bytes read - { "timings", { - { "restore_ms", t_restore_ms } - } } - }; - queue_results.send(result); - } break; - case SERVER_TASK_TYPE_SLOT_ERASE: - { - if (!ensure_no_mtmd(task.id)) break; - int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (!slot->available()) { - // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); - queue_tasks.defer(std::move(task)); - break; - } - - // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); - slot->cache_tokens.clear(); - - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "n_erased", n_erased } - }; - queue_results.send(result); - } break; - case SERVER_TASK_TYPE_SET_LORA: - { - llama_lora_adapters_apply(ctx, lora_adapters); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{ "success", true }}; - queue_results.send(result); - } break; - } - } - - void on_finish_multitask(const server_task_multi & multitask) { - // all subtasks done == multitask is done - server_task_result result; - result.id = multitask.id; - result.stop = true; - result.error = false; - - // collect json results into one json result - std::vector result_jsons; - for (const auto & subres : multitask.results) { - result_jsons.push_back(subres.data); - result.error = result.error && subres.error; - } - result.data = json { - { "results", result_jsons } - }; - - queue_results.send(result); - } - - void print_tokens(const server_tokens & prompt, const server_tokens& cache, size_t start1 = 0, size_t start2=0 , size_t length = 10) { - if (cache.size() > start2) { - LLAMA_LOG_INFO("cache : %s\n", cache.detokenize(ctx, true, start2, length).c_str()); - } - if (prompt.size()> start1) { - LLAMA_LOG_INFO("prompt: %s\n", prompt.detokenize(ctx, true, start1, length).c_str()); - } - - } - - void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) { - llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); - if (slot.params.cache_prompt) { - slot.cache_tokens.discard_n_tokens(n_keep, n_discard); - } - } - - // convert keep first few and discard next tokens in a to b - void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep, - int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool 
exact = false) { - - common_prefix ctx_keep_prefix = a.get_common_prefix_first_n(ctx, b, n_keep, exact); - common_prefix ctx_total_discard_prefix = a.get_common_prefix_first_n(ctx, b, n_discard + n_keep, exact); - // only if there is enough common token - int32_t discard_offset = ctx_total_discard_prefix.first - (n_discard + n_keep); - int32_t keep_offset = ctx_keep_prefix.first - n_keep; - n_kept = ctx_keep_prefix.second - keep_offset; - n_discarded = ctx_total_discard_prefix.second - ctx_keep_prefix.second - discard_offset; - if (n_kept < 0) { - n_kept = n_keep; - } - if (n_discarded < 0) { - n_discarded = n_discard; - } - } - - void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false) { - //server_tokens prompt_tokens = std::move(slot.prompt_tokens); - int n_keep = std::max(0, slot.params.n_keep + add_bos_token); - const int n_left = slot.n_ctx - n_keep; - int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - - int n_discard_prompt = 0; - // we still need to truncate input since we have not discarded enough tokens - while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) { - slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard; - n_discard_prompt = n_discard_prompt + n_discard; - } - - // Handle mistokenization between prompt and cache during context shift - // - int32_t n_discard_cache = n_discard_prompt; - int32_t n_kept = n_keep; - slot.prompt_tokens.discard_n_tokens(n_keep, slot.n_discarded_prompt - n_discard_prompt); - if (n_discard_prompt > 0) { - context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep, - n_discard, n_kept, n_discard_cache, exact); - } - - int n_discard_cache_max = std::max((int32_t)slot.cache_tokens.size() - n_kept, 0); - n_discard_cache = std::min(n_discard_cache, n_discard_cache_max); - // discard matching tokens from cache and kv cache to avoid reprocessing the prompt - if (n_discard_cache > 0) { - discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache); - } - // discard extra tokens from prompts - slot.n_kept_prompt = n_keep; - slot.prompt_tokens.discard_n_tokens(n_keep, n_discard_prompt); - slot.n_prompt_tokens = slot.prompt_tokens.size(); - } - - void update_slots() { - if (system_need_update) { - system_prompt_update(); - } - - // release slots - for (auto & slot : slots) { - if (slot.command == SLOT_COMMAND_RELEASE) { - slot.state = SLOT_STATE_IDLE; - slot.command = SLOT_COMMAND_NONE; - slot.t_last_used = ggml_time_us(); - - LOG_INFO("slot released", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated} - }); - - queue_tasks.notify_slot_changed(); - } - } - - // check if all slots are idle - { - bool all_idle = true; - - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) { - all_idle = false; - break; - } - } - - if (all_idle) { - LOG_INFO("all slots are idle", {}); - if (system_prompt.empty() && clean_kv_cache) { - kv_cache_clear(); - } - - return; - } - } - - { - LOG_VERBOSE("posting NEXT_RESPONSE", {}); - - server_task task; - task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; - task.id_target = -1; - - queue_tasks.post(std::move(task)); - } - - // apply context-shift if needed - // TODO: simplify and improve - for (server_slot & slot : slots) { - if (slot.ga_n == 1) { - if (slot.is_processing() && (int) system_tokens.size() + 
slot.n_past >= slot.n_ctx - 1) { - if (!params.ctx_shift) { - // this check is redundant (for good) - // we should never get here, because generation should already stopped in process_token() - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - // Shift context - int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep; - if (add_bos_token) { - n_keep += 1; - } - n_keep = std::min(slot.n_ctx - 4, n_keep); - - const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - int32_t n_kept; - int32_t n_discard_cache; - if (n_discard > 0) { - context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep, - n_discard, n_kept, n_discard_cache); - LOG_INFO("slot context shift", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_keep", n_keep}, - {"n_left", n_left}, - {"n_discard", n_discard}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()} - }); - slot.n_discarded_prompt = slot.n_discarded_prompt + n_discard; - slot.n_kept_prompt = n_keep; - discard_n_kv_and_cache_tokens(ctx, slot, n_kept, n_discard_cache); - slot.n_past -= n_discard_cache; - slot.truncated = true; - } - - } - } - } - - // start populating the batch for this iteration - llama_batch_clear(batch); - - auto accept_special_token = [&](server_slot& slot, llama_token token) { - return params.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end(); - }; - - // frist, add sampled tokens from any ongoing sequences - for (auto & slot : slots) { - if (slot.state == SLOT_STATE_IDLE) { - continue; - } - - slot.i_batch = batch.n_tokens; - - const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; - - // TODO: we always have to take into account the "system_tokens" - // this is not great and needs to be improved somehow - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.cache_tokens.pos_next(), { slot.id }, true); - - slot.n_past += 1; - - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(slot.sampled); - } - - LOG_VERBOSE("slot decode token", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated} - }); - } - - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); - - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - int32_t batch_type = batch.n_tokens > 0 ? 
0 : -1; - - // next, batch any pending prompts without exceeding n_batch - if (params.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) { - auto & prompt_tokens = slot.prompt_tokens; - - // we haven't tokenized the prompt yet - do it now: - if (prompt_tokens.empty() || slot.n_prompt_tokens==0 ) { - LOG_VERBOSE("tokenizing prompt", { - {"id_slot", slot.id}, - {"id_task", slot.id_task} - }); - - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_generation = 0; - - if (slot.infill) { - const bool add_bos = llama_should_add_bos_token(model); - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - - auto prefix_tokens = tokenize(slot.params.input_prefix, false); - auto suffix_tokens = tokenize(slot.params.input_suffix, false); - - const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); - suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); - - auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; - auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens; - if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); - } - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - - const llama_token middle_token = llama_token_middle(model); - if (middle_token >= 0) { - embd_inp.push_back(middle_token); - } - - prompt_tokens = server_tokens(embd_inp, false); - } else { - // prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt - } - - slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); - - LOG_VERBOSE("prompt tokenized", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, - }); - - // empty prompt passed -> release the slot and send empty response - if (prompt_tokens.empty()) { - LOG_INFO("empty prompt - releasing slot", { - {"id_slot", slot.id}, - {"id_task", slot.id_task} - }); - - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; - slot.release(); - slot.print_timings(); - send_final_response(slot); - continue; - } - - if (slot.embedding) { - // this prompt is too large to process - discard it - if (slot.n_prompt_tokens > n_ubatch) { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; - slot.release(); - send_error(slot, "input is too large to process. 
increase the physical batch size", ERROR_TYPE_SERVER); - continue; - } - } else { - // if input prompt is too big, truncate it (if group attention self-extend is disabled) - // context shift for prompt processing - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { - if (!params.ctx_shift) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - if (mctx) { - // we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - - context_shift_prompt(ctx, slot); - slot.truncated = true; - LOG_VERBOSE("input truncated", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", slot.n_ctx- slot.params.n_keep}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", prompt_tokens.detokenize(ctx, true)}, - }); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - -#ifndef NDEBUG - // debug - common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); - int32_t back = 1; - if (slot.cache_tokens.size() && slot.cache_tokens.size() > prefix.first+20 - && prefix.second >= back && prefix.first >= back) { - LLAMA_LOG_INFO("After context shift :\n"); - print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 50); - } -#endif - } - else { - slot.n_discarded_prompt = 0; - } - llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); - - if (!slot.params.cache_prompt) { - slot.n_past_se = 0; - slot.ga_i = 0; - } else { - GGML_ASSERT(slot.ga_n == 1); - - // reuse any previously computed tokens that are common with the new prompt - common_prefix prefix = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, true); // string level match - common_prefix prefix_nonexact = slot.cache_tokens.get_common_prefix(ctx, prompt_tokens, false); - auto n_past0 = slot.cache_tokens.get_common_prefix_exact(prompt_tokens); // token level match - LLAMA_LOG_INFO("======== Cache: cache_size = %d, n_past0 = %d, n_past1 = %d, n_past_prompt1 = %d, n_past2 = %d, n_past_prompt2 = %d\n", (int32_t) slot.cache_tokens.size(), (int32_t) n_past0, (int32_t) prefix.first, (int32_t)prefix.second, (int32_t) prefix_nonexact.first, (int32_t) prefix_nonexact.second); - int32_t size_threshold = 20; - if (prefix.first + size_threshold < prefix_nonexact.first) { - LLAMA_LOG_WARN("Common part contains missing or extra space and new line\n"); - prefix = prefix_nonexact; - } - slot.n_past = prefix.first; - slot.n_past_prompt = prefix.second; - if (slot.n_past != slot.n_past_prompt) { - LLAMA_LOG_INFO("Mistokenization found and handled successfully.\n"); - } - if ((slot.n_past + size_threshold < slot.cache_tokens.size())) - { - LLAMA_LOG_WARN("Common part does not match fully\n"); - int32_t back = 4; - if (prefix.second >= back && prefix.first >= back) { - print_tokens(slot.prompt_tokens, slot.cache_tokens, prefix.second - back, prefix.first - back, 30); - } - } - - // push the prompt into the sampling context (do not apply grammar) - for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); - } - } - } - - if (slot.n_past_prompt == slot.n_prompt_tokens && slot.n_past_prompt > 0) { - // we have to evaluate at least 1 token to generate logits. 
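The block above reuses a slot's previous KV cache by matching the new prompt against the cached tokens (at the token level and, to tolerate mistokenization, at the string level), and then makes sure at least one prompt token is still evaluated so the model produces logits to sample from. A minimal sketch of the token-level part of that idea, using hypothetical helper names rather than the server's actual API:

#include <cstdint>
#include <algorithm>
#include <vector>

using llama_token = int32_t; // stand-in for the llama.h token type

// Longest common token-level prefix between the cached tokens and the new prompt.
static size_t common_token_prefix(const std::vector<llama_token> & cache,
                                  const std::vector<llama_token> & prompt) {
    size_t n = 0;
    const size_t lim = std::min(cache.size(), prompt.size());
    while (n < lim && cache[n] == prompt[n]) {
        n++;
    }
    return n;
}

// Number of prompt tokens whose KV entries can be kept. If the entire prompt is
// already cached, back off by one so that at least one token is re-evaluated and
// the model produces logits for sampling the next token.
static size_t reusable_prefix(const std::vector<llama_token> & cache,
                              const std::vector<llama_token> & prompt) {
    size_t n_past = common_token_prefix(cache, prompt);
    if (n_past == prompt.size() && n_past > 0) {
        n_past--;
    }
    return n_past;
}

The final back-off in this sketch corresponds to the slot.n_past_prompt-- / slot.n_past-- decrements just below in the diff.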
- LOG_INFO("we have to evaluate at least 1 token to generate logits", { - { "id_slot", slot.id }, - { "id_task", slot.id_task } - }); - - slot.n_past_prompt--; - slot.n_past--; - if (slot.ga_i > 0) { - slot.n_past_se--; - } - } - - slot.n_prompt_tokens_processed = 0; - } - - if (slot.embedding) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { - continue; - } - } - - // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.embedding ? 1 : 0; - if (batch_type == -1) { - batch_type = slot_type; - } else if (batch_type != slot_type) { - continue; - } - - // keep only the common part - // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); - int p0 = (int) system_tokens.size() + slot.n_past; - p0 = system_tokens.size() + slot.cache_tokens.pos_next(); - if (!llama_kv_cache_seq_rm(ctx, slot.id, p0, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); - - p0 = (int) system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id, -1, -1); - } - - // there is no common part left (except for the system prompt) - slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? - llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); - } - - LOG_INFO("kv cache rm [p0, end)", { - { "id_slot", slot.id }, - { "id_task", slot.id_task }, - { "p0", p0 } - }); - - // check if we should process the image - if (slot.n_past_prompt < slot.n_prompt_tokens - && slot.prompt_tokens[slot.n_past_prompt] == LLAMA_TOKEN_NULL) { - // process the image - size_t n_tokens_out = 0; - llama_pos p1 = slot.cache_tokens.pos_next()+slot.n_past_prompt-slot.n_past; // add offset to prompt - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past_prompt, p1, slot.id, n_tokens_out); - if (res != 0) { - LLAMA_LOG_ERROR("failed to process image, res = %d\n", res); - slot.release(); - send_error(slot, "failed to process image", ERROR_TYPE_SERVER); - continue; - } - - // add the image chunk to cache - { - const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past_prompt); - slot.cache_tokens.push_back(chunk.get()); // copy - } - - slot.n_past += n_tokens_out; - slot.n_past_prompt += n_tokens_out; - slot.n_prompt_tokens_processed += n_tokens_out; - - } - - - - int32_t slot_npast = slot.n_past_se > 0 ? 
slot.n_past_se : slot.n_past; - - int32_t ga_i = slot.ga_i; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - - // add prompt tokens for processing in the current batch - // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow - while (slot.n_past_prompt < slot.n_prompt_tokens && batch.n_tokens < n_batch) { - // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past_prompt]; - if (cur_tok == LLAMA_TOKEN_NULL) { - break; // end of text chunk - } - if (slot.ga_n != 1) { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; - } - } - - int p0=system_tokens.size() + slot.cache_tokens.pos_next(); - llama_batch_add(batch, cur_tok, p0, { slot.id }, false); - - slot.cache_tokens.push_back(cur_tok); - - - slot.n_prompt_tokens_processed++; - slot_npast++; - slot.n_past_prompt++; - slot.n_past++; - } - LOG_VERBOSE("prompt processing progress", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens}, - }); - - // entire prompt has been processed - start decoding new tokens - if (slot.n_past_prompt == slot.n_prompt_tokens) { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; - - GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t)slot.n_prompt_tokens == slot.prompt_tokens.size()); - llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling); - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - llama_token id = slot.prompt_tokens[i]; - if (id != LLAMA_TOKEN_NULL) { - llama_sampling_accept(slot.ctx_sampling, ctx, id, false); - } - } - - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; - - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - LOG_VERBOSE("prompt done", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - }); - } - } - - if (batch.n_tokens >= n_batch) { - break; - } - } - } - - if (batch.n_tokens == 0) { - LOG_VERBOSE("no tokens to decode", {}); - return; - } - - LOG_VERBOSE("decoding batch", { - {"n_tokens", batch.n_tokens}, - }); - - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); - - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - - for (auto & slot : slots) { - if (slot.ga_n != 1) { - // context extension via Self-Extend - // TODO: simplify and/or abstract this - while (slot.n_past_se >= slot.ga_i + slot.ga_w) { - const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; - const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); - const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; - - LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); - LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i 
+ ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); - - slot.n_past_se -= bd; - - slot.ga_i += slot.ga_w / slot.ga_n; - - LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); - } - - slot.n_past_se += n_tokens; - } - } - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; - - const int ret = llama_decode(ctx, batch_view); - - if (ret != 0) { - if (n_batch == 1 || ret < 0) { - // if you get here, it means the KV cache is full - try increasing it via the context size - LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", { - {"i", i}, - {"n_batch", ret}, - {"ret", ret}, - }); - for (auto & slot : slots) { - slot.state = SLOT_STATE_PROCESSING; - slot.command = SLOT_COMMAND_NONE; - slot.release(); - LLAMA_LOG_INFO("n_past = %d\n", (int)slot.cache_tokens.size()); - send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); - } - break; // break loop of n_batch - } - - - // retry with half the batch size to try to find a free slot in the KV cache - n_batch /= 2; - i -= n_batch; - - LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, - }); - - continue; // continue loop of n_batch - } - - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots - } - - // prompt evaluated for embedding - if (slot.embedding) { - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - completion_token_output result; - const int tok_idx = slot.i_batch - i; - const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx); - - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); - - slot.n_decoded += 1; - - const int64_t t_current = ggml_time_us(); - - if (slot.n_decoded == 1) { - slot.t_start_generation = ggml_time_us(); - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; - - result.tok = id; - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - - if (slot.sparams.n_probs > 0) { - populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx); - } - - if (!process_token(result, slot)) { - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - } - - slot.i_batch = -1; - } - - // Do speculative decoding - for (auto & slot : slots) { - if (!slot.is_processing() || !slot.spec) { - continue; - } - - if (slot.state != SLOT_STATE_PROCESSING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.params.speculative.n_max; - - // note: 
n_past is not yet increased for the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); - - if (slot.n_predict > 0) { - n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1); - } - - LOG_VERBOSE("max possible draft", { - {"id_slot", slot.id}, - {"n_draft_max", n_draft_max} - }); - - if (n_draft_max < slot.params.speculative.n_min) { - LOG_VERBOSE("the max possible draft is too small", { - {"id_slot", slot.id}, - {"n_draft_max", n_draft_max}, - {"n_min", slot.params.speculative.n_min} - }); - continue; - } - - llama_token id = slot.sampled; - - struct llama_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; - - const std::vector & cached_text_tokens = slot.cache_tokens.tokens_data(); - std::vector draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - LOG_VERBOSE("ignoring small draft", { - {"id_slot", slot.id}, - {"draft_size", (int) draft.size()}, - {"n_min", slot.params.speculative.n_min} - }); - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - llama_batch_clear(slot.batch_spec); - llama_batch_add(slot.batch_spec, id, slot.cache_tokens.pos_next(), { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_add(slot.batch_spec, draft[i], slot.cache_tokens.pos_next() + 1 + i, { slot.id }, true); - } - - LOG_VERBOSE("decoding speculative batch", { - {"id_slot", slot.id}, - {"size", slot.batch_spec.n_tokens} - }); - - llama_decode(ctx, slot.batch_spec); - - // the accepted tokens from the speculation - std::vector ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft); - - slot.n_past += ids.size(); - slot.n_decoded += ids.size(); - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert({ ids.begin(), ids.end() - 1 }); - - llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - if (slot.sparams.n_probs > 0) { - populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - break; - } - } - - LOG_VERBOSE("speculative decoding result", { - {"id_slot", slot.id}, - {"accepted", (int) ids.size() - 1}, - {"total", (int) draft.size()}, - {"new_n_past", slot.n_past} - }); - } - } - - LOG_VERBOSE("run slots completed", {}); - } - - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (model)}, - {"n_vocab", llama_n_vocab (model)}, - {"n_ctx_train", llama_n_ctx_train (model)}, - {"n_embd", llama_n_embd (model)}, - {"n_params", llama_model_n_params(model)}, - {"size", llama_model_size (model)}, - }; - } -}; - static json format_final_response_oaicompat(const json& request, json 
result, const std::string& completion_id, bool streaming = false) { bool stopped_word = result.count("stopped_word") != 0; bool stopped_eos = json_value(result, "stopped_eos", false); @@ -4408,31 +236,6 @@ static std::vector format_partial_response_oaicompat(server_task_result ta } -//static json format_embeddings_response_oaicompat(const json& request, const json& embeddings) { -// json data = json::array(); -// int32_t n_tokens = 0; -// int i = 0; -// for (auto& elem : embeddings) { -// data.push_back(json{ -// {"embedding", json_value(elem, "embedding", json::array())}, -// {"index", i++}, -// {"object", "embedding"} -// }); -// } -// -// json res = json{ -// {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, -// {"object", "list"}, -// {"usage", json { -// {"prompt_tokens", n_tokens}, -// {"total_tokens", n_tokens} -// }}, -// {"data", data} -// }; -// -// return res; -//} - static json format_embeddings_response_oaicompat(const json& request, const json& embeddings, bool use_base64 = false) { json data = json::array(); int32_t n_tokens = 0; @@ -5097,7 +900,7 @@ int main(int argc, char ** argv) { if (ctx_server.params.use_jinja) { if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { data["chat_template_tool_use"] = tool_use_src; - } + } } res.set_content(data.dump(), "application/json; charset=utf-8"); }; @@ -5400,7 +1203,7 @@ int main(int argc, char ** argv) { std::string content; if (body.count("tokens") != 0) { const std::vector tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + content = tokens_to_str(ctx_server.ctx, tokens); } const json data = format_detokenized_response(content); @@ -5555,7 +1358,7 @@ int main(int argc, char ** argv) { {"filesize", entry.file_size()}, {"mtime", str_time}, {"token_count", n_token_count}, - {"prompt", tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend())} + {"prompt", tokens_to_str(ctx_server.ctx, tokens)} }); } } catch (const std::exception& e) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 592d3998..b4ff8913 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -773,7 +773,6 @@ void launch_fattn( size_t nb11 = K->nb[1]; size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - char * V_data = (char *) V->data; size_t nb21 = V->nb[1]; size_t nb22 = V->nb[2];
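One note on the launch_fattn hunk above: ggml tensors are addressed through per-dimension byte strides (the nb11/nb12/nb13 and nb21/nb22 values kept here), so any K or V row can be reached by offsetting the base data pointer directly; the local V_data pointer being deleted appears to be simply unused. A rough, hypothetical illustration of that stride addressing (not ggml's internal code):

#include <cstddef>

// With row-major data and per-dimension byte strides nb1..nb3, the start of the
// row at indices (i1, i2, i3) is the base pointer plus a byte offset.
static inline const char * row_ptr(const void * data,
                                   size_t i1, size_t i2, size_t i3,
                                   size_t nb1, size_t nb2, size_t nb3) {
    return (const char *) data + i1*nb1 + i2*nb2 + i3*nb3;
}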