Refactor chat and server file (#1062)

* Add alternative log functions

* chat: fix int overflow, prevent size calculation in float/double (#17357)

* chat: fix int overflow, prevent size calculation in float/double

* Update common/chat.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* common : move all common_chat_parse_* to chat-parser.cpp. (#17481)

# Conflicts:
#	common/chat.cpp

* server: split server.cpp code into server/common/task/queue/context

* Fix compiler warning

* Clean up code

* common: use native MultiByteToWideChar

* move server prompt to server task

* Clean code

* delete utils.hpp

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: DAN™ <dranger003@gmail.com>
This commit is contained in:
firecoperana
2025-12-15 01:27:20 -06:00
committed by GitHub
parent 0a36cea555
commit 090f354d33
20 changed files with 6849 additions and 5613 deletions

View File

@@ -71,6 +71,8 @@ add_library(${TARGET} STATIC
json-schema-to-grammar.cpp
train.h
train.cpp
log.cpp
log.h
ngram-cache.h
ngram-cache.cpp
speculative.cpp

File diff suppressed because it is too large Load Diff

View File

@@ -63,6 +63,9 @@ class common_chat_msg_parser {
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
// Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
bool add_tool_call_short_form(const nlohmann::ordered_json& tool_call);
void finish();
bool consume_spaces();

View File

@@ -627,6 +627,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
@@ -638,6 +639,10 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
@@ -676,114 +681,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string& for
throw std::runtime_error("Unknown reasoning format: " + format);
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
arguments = (json {{"code", code + builder.healing_marker()}}).dump();
auto idx = arguments.find(builder.healing_marker());
if (idx != std::string::npos) {
arguments.resize(idx);
}
} else {
arguments = (json {{"code", code}}).dump();
}
return arguments;
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static void parse_json_tool_calls(
common_chat_msg_parser & builder,
const std::optional<common_regex> & block_open,
const std::optional<common_regex> & function_regex_start_only,
const std::optional<common_regex> & function_regex,
const common_regex & close_regex,
const std::optional<common_regex> & block_close,
bool allow_raw_python = false,
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name = nullptr) {
auto parse_tool_calls = [&]() {
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first
? builder.try_consume_regex(*function_regex_start_only)
: function_regex
? builder.try_find_regex(*function_regex, from)
: std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
name = get_function_name(*res);
} else {
GGML_ASSERT(res->groups.size() == 2);
name = builder.str(res->groups[1]);
}
first = false;
if (name.empty()) {
// get_function_name signalled us that we should skip this match and treat it as content.
from = res->groups[0].begin + 1;
continue;
}
from = std::string::npos;
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(close_regex);
}
continue;
}
if (maybe_raw_python) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
if (!builder.add_tool_call(name, "", arguments)) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
if (block_close) {
builder.consume_regex(*block_close);
}
builder.consume_spaces();
builder.add_content(builder.consume_rest());
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
}
} else {
parse_tool_calls();
}
}
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
if (auto res = builder.try_find_regex(prefix)) {
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
for (const auto & tool : tools) {
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
@@ -915,37 +812,6 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
data.format = COMMON_CHAT_FORMAT_GENERIC;
return data;
}
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
static const std::vector<std::vector<std::string>> args_paths = {
{"tool_call", "arguments"},
{"tool_calls", "arguments"},
};
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
if (data.value.contains("tool_calls")) {
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool calls");
}
} else if (data.value.contains("tool_call")) {
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (data.value.contains("response")) {
const auto & response = data.value.at("response");
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
if (data.is_partial) {
throw common_chat_msg_partial_exception("incomplete response");
}
} else {
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
}
}
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -991,16 +857,6 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
return data;
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -1081,39 +937,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
return data;
}
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
@@ -1212,63 +1035,6 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
});
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
common_healing_marker healing_marker;
json args = json::object();
while (true) {
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
auto arg_name = builder.str(arg_res->groups[1]);
auto partial = builder.consume_json();
args[arg_name] = partial.json;
healing_marker.marker = partial.healing_marker.marker;
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
builder.consume_spaces();
if (!builder.try_consume_literal(",")) {
break;
}
} else {
break;
}
}
builder.consume_literal(")");
builder.consume_spaces();
auto arguments = args.dump();
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);
}
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -1408,88 +1174,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha
return data;
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
static const common_regex function_regex("(?:<tool▁call▁begin>)?function<tool▁sep>([^\n]+)\n```json\n");
static const common_regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<tool▁sep>)");
static const common_regex close_regex("(?:[\\s]*)?<tool▁call▁end>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1532,20 +1216,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
return data;
}
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<minimax:tool_call>",
/* form.tool_start = */ "<invoke name=\"",
/* form.tool_sep = */ "\">",
/* form.key_start = */ "<parameter name=\"",
/* form.key_val_sep = */ "\">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</invoke>",
/* form.scope_end = */ "</minimax:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1578,23 +1248,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c
return data;
}
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1639,25 +1292,6 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
return data;
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<|tool_calls_section_begin|>";
form.tool_start = "<|tool_call_begin|>";
form.tool_sep = "<|tool_call_argument_begin|>{";
form.key_start = "\"";
form.key_val_sep = "\":";
form.val_end = ",";
form.tool_end = "}<|tool_call_end|>";
form.scope_end = "<|tool_calls_section_end|>";
form.raw_argval = false;
form.last_val_end = "";
form.allow_toolcall_in_think = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1693,25 +1327,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t
return data;
}
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_calls>[";
form.tool_start = "{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}, ";
form.scope_end = "]</tool_calls>";
form.raw_argval = false;
form.last_val_end = "";
form.last_tool_end = "}";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1744,24 +1359,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_
return data;
}
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "";
form.tool_start = "<tool_call>\n{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}\n</tool_call>";
form.scope_end = "";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
@@ -1892,93 +1489,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
return data;
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex analysis_regex("<\\|channel\\|>analysis");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
static const common_regex preamble_regex("<\\|channel\\|>commentary");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_literal("<|end|>")) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};
auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
};
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
auto match = regex.search(input, 0, true);
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
return match;
}
return std::nullopt;
};
do {
auto header_start_pos = builder.pos();
auto content_start = builder.try_find_literal("<|message|>");
if (!content_start) {
throw common_chat_msg_partial_exception("incomplete header");
}
auto header = content_start->prelude;
if (auto match = regex_match(tool_call1_regex, header)) {
auto group = match->groups[1];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (auto match = regex_match(tool_call2_regex, header)) {
auto group = match->groups[2];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (regex_match(analysis_regex, header)) {
builder.move_to(header_start_pos);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
continue;
}
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
builder.add_content(consume_end());
continue;
}
// Possibly a malformed message, attempt to recover by rolling
// back to pick up the next <|start|>
LOG("%s: unknown header from message: %s\n", __func__, header.c_str());
builder.move_to(header_start_pos);
} while (builder.try_find_regex(start_regex, std::string::npos, false));
auto remaining = builder.consume_rest();
if (!remaining.empty()) {
LOG("%s: content after last message: %s\n", __func__, remaining.c_str());
}
}
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -2059,21 +1569,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
return data;
}
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "",
/* form.tool_start = */ "<tool_call>",
/* form.tool_sep = */ "",
/* form.key_start = */ "<arg_key>",
/* form.key_val_sep = */ "</arg_key>",
/* form.val_end = */ "</arg_value>",
/* form.tool_end = */ "</tool_call>",
/* form.scope_end = */ "",
/* form.key_val_sep2 = */ "<arg_value>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG("%s\n", __func__);
common_chat_params data;
@@ -2119,14 +1614,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
}
return data;
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
@@ -2177,34 +1664,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
}
return data;
}
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
static const common_regex close_regex(R"(\s*)");
parse_json_tool_calls(
builder,
std::nullopt,
function_regex_start_only,
function_regex,
close_regex,
std::nullopt,
/* allow_raw_python= */ true,
/* get_function_name= */ [&](const auto & res) -> std::string {
auto at_start = res.groups[0].begin == 0;
auto name = builder.str(res.groups[1]);
if (!name.empty() && name.back() == '{') {
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
builder.move_back(1);
}
auto idx = name.find_last_not_of("\n{");
name = name.substr(0, idx + 1);
if (at_start && name == "all") {
return "";
}
return name;
});
}
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
@@ -2264,31 +1723,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
// TODO: if (has_raw_python)
return data;
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -2405,83 +1839,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
return data;
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
const auto & open_tag = res->groups[2];
std::string close_tag;
if (!res->groups[3].empty()) {
builder.move_to(res->groups[3].begin);
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
}
}
}
builder.add_content(builder.consume_rest());
}
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -2564,53 +1921,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
return data;
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
static const common_regex start_think_regex(regex_escape("<think>"));
static const common_regex end_think_regex(regex_escape("</think>"));
// Granite models output partial tokens such as "<" and "<think".
// By leveraging try_consume_regex()/try_find_regex() throwing
// common_chat_msg_partial_exception for these partial tokens,
// processing is interrupted and the tokens are not passed to add_content().
if (auto res = builder.try_consume_regex(start_think_regex)) {
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
builder.try_find_regex(end_think_regex, std::string::npos, false);
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
}
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags
static const common_regex start_response_regex(regex_escape("<response>"));
static const common_regex end_response_regex(regex_escape("</response>"));
// Granite models output partial tokens such as "<" and "<response".
// Same hack as reasoning parsing.
if (builder.try_consume_regex(start_response_regex)) {
builder.try_find_regex(end_response_regex);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
} else {
builder.add_content(builder.consume_rest());
}
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -2802,7 +2112,7 @@ static common_chat_params common_chat_templates_apply_legacy(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
int alloc_size = 0;
size_t alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
@@ -2824,7 +2134,8 @@ static common_chat_params common_chat_templates_apply_legacy(
const auto & msg = inputs.messages[i];
const auto & content = contents[i];
chat.push_back({msg.role.c_str(), content.c_str()});
alloc_size += (msg.role.size() + content.size()) * 1.25;
size_t msg_size = msg.role.size() + content.size();
alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
}
std::vector<char> buf(alloc_size);
@@ -2846,6 +2157,11 @@ static common_chat_params common_chat_templates_apply_legacy(
res = llama_chat_apply_template(nullptr, src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
}
// for safety, we check the result again
if (res < 0 || (size_t) res > buf.size()) {
throw std::runtime_error("failed to apply chat template, try using --jinja");
}
common_chat_params params;
params.prompt = std::string(buf.data(), res);
if (!inputs.json_schema.empty()) {
@@ -2865,97 +2181,3 @@ common_chat_params common_chat_templates_apply(
? common_chat_templates_apply_jinja(tmpls, inputs)
: common_chat_templates_apply_legacy(tmpls, inputs);
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
case COMMON_CHAT_FORMAT_GENERIC:
common_chat_parse_generic(builder);
break;
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
common_chat_parse_functionary_v3_1_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
common_chat_parse_hermes_2_pro(builder);
break;
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
common_chat_parse_firefunction_v2(builder);
break;
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_MINIMAX_M2:
common_chat_parse_minimax_m2(builder);
break;
case COMMON_CHAT_FORMAT_GLM_4_5:
common_chat_parse_glm_4_5(builder);
break;
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG("Partial parse: %s\n", ex.what());
if (!is_partial) {
builder.clear_tools();
builder.move_to(0);
common_chat_parse_content_only(builder);
}
}
auto msg = builder.result();
if (!is_partial) {
LOG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}

View File

@@ -101,6 +101,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_MAGISTRAL,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
@@ -112,6 +113,10 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_APERTUS,
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
COMMON_CHAT_FORMAT_GLM_4_5,
COMMON_CHAT_FORMAT_MINIMAX_M2,
COMMON_CHAT_FORMAT_KIMI_K2,

View File

@@ -2726,11 +2726,29 @@ bool fs_validate_filename(const std::string & filename) {
return true;
}
#ifdef _WIN32
static std::wstring utf8_to_wstring(const std::string& str) {
if (str.empty()) {
return std::wstring();
}
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
if (size <= 0) {
return std::wstring();
}
std::wstring wstr(size, 0);
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
return wstr;
}
#endif
// returns true if successful, false otherwise
bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wpath = converter.from_bytes(path);
std::wstring wpath = utf8_to_wstring(path);
// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());
@@ -3586,175 +3604,6 @@ bool llama_should_add_bos_token(const llama_model * model) {
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}
//
// Chat template utils
//
//
//bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
// if (use_jinja) {
// try {
// auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
// common_chat_inputs inputs;
// inputs.messages = json::array({ {
// {"role", "user"},
// {"content", "test"},
// } });
// common_chat_params_init(chat_template, inputs);
// return true;
// }
// catch (const std::exception& e) {
// fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
// return false;
// }
// }
// llama_chat_message chat[] = { {"user", "test"} };
// const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
// return res >= 0;
//}
//std::string llama_chat_apply_template(const struct llama_model * model,
// const common_chat_template& tmpl,
// const std::vector<common_chat_msg> & msgs,
// bool add_ass,
// bool use_jinja) {
// if (use_jinja) {
// auto messages = json::array();
// for (const auto& msg : msgs) {
// messages.push_back({ {"role", msg.role}, {"content", msg.content} });
// }
// common_chat_inputs inputs;
// inputs.messages = messages;
// inputs.add_generation_prompt = add_ass;
// return common_chat_params_init(tmpl, inputs).prompt;
// }
// int alloc_size = 0;
// std::vector<llama_chat_message> chat;
// for (auto & msg : msgs) {
// chat.push_back({msg.role.c_str(), msg.content.c_str()});
// alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
// }
//
// std::vector<char> buf(alloc_size);
//
// // run the first time to get the total output length
// int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// // error: chat template is not supported
// if (res < 0) {
// // if the custom "tmpl" is not supported, we throw an error
// // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
// throw std::runtime_error("this custom template is not supported");
// }
//
// // if it turns out that our buffer is too small, we resize it
// if ((size_t)res > buf.size()) {
// buf.resize(res);
// res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// }
//
// std::string formatted_chat(buf.data(), res);
// return formatted_chat;
//}
////
//std::string llama_chat_format_single(const struct llama_model * model,
// const common_chat_template& tmpl,
// const std::vector<common_chat_msg> & past_msg,
// const common_chat_msg & new_msg,
// bool add_ass,
// bool use_jinja) {
// std::ostringstream ss;
// auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
// std::vector<common_chat_msg> chat_new(past_msg);
// // if the past_msg ends with a newline, we must preserve it in the formatted version
// if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
// ss << "\n";
// };
// // format chat with new_msg
// chat_new.push_back(new_msg);
// auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
// // get the diff part
// ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
// return ss.str();
//}
//std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
// std::vector<common_chat_msg> msgs = {
// {"system", "You are a helpful assistant", {}},
// {"user", "Hello", {}},
// {"assistant", "Hi there", {}},
// {"user", "How are you?", {}},
// };
// return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
//}
//
//#define CHATML_TEMPLATE_SRC \
// "{%- for message in messages -%}\n" \
// " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
// "{%- endfor -%}\n" \
// "{%- if add_generation_prompt -%}\n" \
// " {{- '<|im_start|>assistant\n' -}}\n" \
// "{%- endif -%}"
//
//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
//{
// std::string default_template_src;
// std::string template_tool_use_src;
// bool has_explicit_template = !chat_template_override.empty();
// if (chat_template_override.empty()) {
// auto str = llama_model_chat_template(model, /* name */ nullptr);
// if (str) {
// default_template_src = str;
// has_explicit_template = true;
// }
// str = llama_model_chat_template(model, /* name */ "tool_use");
// if (str) {
// template_tool_use_src = str;
// has_explicit_template = true;
// }
// }
// else {
// default_template_src = chat_template_override;
// }
// if (default_template_src.empty() || default_template_src == "chatml") {
// if (!template_tool_use_src.empty()) {
// default_template_src = template_tool_use_src;
// }
// else {
// default_template_src = CHATML_TEMPLATE_SRC;
// }
// }
// auto vocab = llama_model_get_vocab(model);
// const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
// if (token == LLAMA_TOKEN_NULL) {
// if (default_template_src.find(jinja_variable_name) != std::string::npos
// || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
// fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
// }
// return std::string();
// }
// else {
// return llama_token_to_piece(model, token, true);
// }
// };
// auto token_bos = get_token(llama_token_bos_impl(*vocab), "BOS", "bos_token");
// auto token_eos = get_token(llama_token_eos_impl(*vocab), "EOS", "eos_token");
// try {
// return {
// has_explicit_template,
// std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
// template_tool_use_src.empty()
// ? nullptr
// : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
// };
// }
// catch (const std::exception& e) {
// LOG("%s: failed to parse chat template: %s\n", __func__, e.what());
// return {
// has_explicit_template,
// std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
// nullptr,
// };
// }
//}
//
// KV cache utils

464
common/log.cpp Normal file
View File

@@ -0,0 +1,464 @@
#include "log.h"
#include <chrono>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <sstream>
#include <thread>
#include <vector>
#if defined(_WIN32)
# include <io.h>
# include <windows.h>
# define isatty _isatty
# define fileno _fileno
#else
# include <unistd.h>
#endif // defined(_WIN32)
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity;
}
// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
// colors
enum common_log_col : int {
COMMON_LOG_COL_DEFAULT = 0,
COMMON_LOG_COL_BOLD,
COMMON_LOG_COL_RED,
COMMON_LOG_COL_GREEN,
COMMON_LOG_COL_YELLOW,
COMMON_LOG_COL_BLUE,
COMMON_LOG_COL_MAGENTA,
COMMON_LOG_COL_CYAN,
COMMON_LOG_COL_WHITE,
};
// disable colors by default
static std::vector<const char *> g_col = {
"",
"",
"",
"",
"",
"",
"",
"",
"",
};
struct common_log_entry {
enum ggml_log_level level;
bool prefix;
int64_t timestamp;
std::vector<char> msg;
// signals the worker thread to stop
bool is_end;
void print(FILE * file = nullptr) const {
FILE * fcur = file;
if (!fcur) {
// stderr displays DBG messages only when their verbosity level is not higher than the threshold
// these messages will still be logged to a file
if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
return;
}
fcur = stdout;
if (level != GGML_LOG_LEVEL_NONE) {
fcur = stderr;
}
}
if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
if (timestamp) {
// [M.s.ms.us]
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
g_col[COMMON_LOG_COL_BLUE],
(int) (timestamp / 1000000 / 60),
(int) (timestamp / 1000000 % 60),
(int) (timestamp / 1000 % 1000),
(int) (timestamp % 1000),
g_col[COMMON_LOG_COL_DEFAULT]);
}
switch (level) {
case GGML_LOG_LEVEL_INFO: fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN], g_col[COMMON_LOG_COL_DEFAULT]); break;
case GGML_LOG_LEVEL_WARN: fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], "" ); break;
case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED], "" ); break;
case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW], "" ); break;
default:
break;
}
}
fprintf(fcur, "%s", msg.data());
if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
}
fflush(fcur);
}
};
struct common_log {
// default capacity - will be expanded if needed
common_log() : common_log(256) {}
common_log(size_t capacity) {
file = nullptr;
prefix = false;
timestamps = false;
running = false;
t_start = t_us();
// initial message size - will be expanded if longer messages arrive
entries.resize(capacity);
for (auto & entry : entries) {
entry.msg.resize(256);
}
head = 0;
tail = 0;
resume();
}
~common_log() {
pause();
if (file) {
fclose(file);
}
}
private:
std::mutex mtx;
std::thread thrd;
std::condition_variable cv;
FILE * file;
bool prefix;
bool timestamps;
bool running;
int64_t t_start;
// ring buffer of entries
std::vector<common_log_entry> entries;
size_t head;
size_t tail;
// worker thread copies into this
common_log_entry cur;
public:
void add(enum ggml_log_level level, const char * fmt, va_list args) {
std::lock_guard<std::mutex> lock(mtx);
if (!running) {
// discard messages while the worker thread is paused
return;
}
auto & entry = entries[tail];
{
// cannot use args twice, so make a copy in case we need to expand the buffer
va_list args_copy;
va_copy(args_copy, args);
#if 1
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
if (n >= entry.msg.size()) {
entry.msg.resize(n + 1);
vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
}
#else
// hack for bolding arguments
std::stringstream ss;
for (int i = 0; fmt[i] != 0; i++) {
if (fmt[i] == '%') {
ss << LOG_COL_BOLD;
while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
ss << LOG_COL_DEFAULT;
if (fmt[i] == 0) break;
}
ss << fmt[i];
}
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
if (n >= entry.msg.size()) {
entry.msg.resize(n + 1);
vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
}
#endif
va_end(args_copy);
}
entry.level = level;
entry.prefix = prefix;
entry.timestamp = 0;
if (timestamps) {
entry.timestamp = t_us() - t_start;
}
entry.is_end = false;
tail = (tail + 1) % entries.size();
if (tail == head) {
// expand the buffer
std::vector<common_log_entry> new_entries(2*entries.size());
size_t new_tail = 0;
do {
new_entries[new_tail] = std::move(entries[head]);
head = (head + 1) % entries.size();
new_tail = (new_tail + 1);
} while (head != tail);
head = 0;
tail = new_tail;
for (size_t i = tail; i < new_entries.size(); i++) {
new_entries[i].msg.resize(256);
}
entries = std::move(new_entries);
}
cv.notify_one();
}
void resume() {
std::lock_guard<std::mutex> lock(mtx);
if (running) {
return;
}
running = true;
thrd = std::thread([this]() {
while (true) {
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return head != tail; });
cur = entries[head];
head = (head + 1) % entries.size();
}
if (cur.is_end) {
break;
}
cur.print(); // stdout and stderr
if (file) {
cur.print(file);
}
}
});
}
void pause() {
{
std::lock_guard<std::mutex> lock(mtx);
if (!running) {
return;
}
running = false;
// push an entry to signal the worker thread to stop
{
auto & entry = entries[tail];
entry.is_end = true;
tail = (tail + 1) % entries.size();
}
cv.notify_one();
}
thrd.join();
}
void set_file(const char * path) {
pause();
if (file) {
fclose(file);
}
if (path) {
file = fopen(path, "w");
} else {
file = nullptr;
}
resume();
}
void set_colors(bool colors) {
pause();
if (colors) {
g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD;
g_col[COMMON_LOG_COL_RED] = LOG_COL_RED;
g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN;
g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW;
g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE;
g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
} else {
for (size_t i = 0; i < g_col.size(); i++) {
g_col[i] = "";
}
}
resume();
}
void set_prefix(bool prefix) {
std::lock_guard<std::mutex> lock(mtx);
this->prefix = prefix;
}
void set_timestamps(bool timestamps) {
std::lock_guard<std::mutex> lock(mtx);
this->timestamps = timestamps;
}
};
//
// public API
//
struct common_log * common_log_init() {
return new common_log;
}
struct common_log * common_log_main() {
static struct common_log log;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
// Set default to auto-detect colors
log.set_colors(common_log_should_use_colors_auto());
});
return &log;
}
void common_log_pause(struct common_log * log) {
log->pause();
}
void common_log_resume(struct common_log * log) {
log->resume();
}
void common_log_free(struct common_log * log) {
delete log;
}
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) {
va_list args;
va_start(args, fmt);
log->add(level, fmt, args);
va_end(args);
}
void common_log_set_file(struct common_log * log, const char * file) {
log->set_file(file);
}
void common_log_set_colors(struct common_log * log, log_colors colors) {
if (colors == LOG_COLORS_AUTO) {
log->set_colors(common_log_should_use_colors_auto());
return;
}
if (colors == LOG_COLORS_DISABLED) {
log->set_colors(false);
return;
}
GGML_ASSERT(colors == LOG_COLORS_ENABLED);
log->set_colors(true);
}
void common_log_set_prefix(struct common_log * log, bool prefix) {
log->set_prefix(prefix);
}
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
log->set_timestamps(timestamps);
}
static int common_get_verbosity(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
case GGML_LOG_LEVEL_NONE:
default:
return LOG_LEVEL_OUTPUT;
}
}
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
auto verbosity = common_get_verbosity(level);
if (verbosity <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}

View File

@@ -1,4 +1,5 @@
#pragma once
#include "ggml.h" // for ggml_log_level
#include <chrono>
#include <cstring>
#include <sstream>
@@ -8,6 +9,124 @@
#include <algorithm>
#include <cinttypes>
#define LOG_CLR_TO_EOL "\033[K\r"
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
#define LOG_COL_GREEN "\033[32m"
#define LOG_COL_YELLOW "\033[33m"
#define LOG_COL_BLUE "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN "\033[36m"
#define LOG_COL_WHITE "\033[37m"
#ifndef __GNUC__
# define LOG_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__) && !defined(__clang__)
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#define LOG_LEVEL_DEBUG 4
#define LOG_LEVEL_INFO 3
#define LOG_LEVEL_WARN 2
#define LOG_LEVEL_ERROR 1
#define LOG_LEVEL_OUTPUT 0 // output data from tools
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
enum log_colors {
LOG_COLORS_AUTO = -1,
LOG_COLORS_DISABLED = 0,
LOG_COLORS_ENABLED = 1,
};
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
// set via common_log_set_verbosity()
extern int common_log_verbosity_thold;
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
void common_log_default_callback(enum ggml_log_level level, const char* text, void* user_data);
// the common_log uses an internal worker thread to print/write log messages
// when the worker thread is paused, incoming log messages are discarded
struct common_log;
struct common_log* common_log_init();
struct common_log* common_log_main(); // singleton, automatically destroys itself on exit
void common_log_pause(struct common_log* log); // pause the worker thread, not thread-safe
void common_log_resume(struct common_log* log); // resume the worker thread, not thread-safe
void common_log_free(struct common_log* log);
LOG_ATTRIBUTE_FORMAT(3, 4)
void common_log_add(struct common_log* log, enum ggml_log_level level, const char* fmt, ...);
// defaults: file = NULL, colors = false, prefix = false, timestamps = false
//
// regular log output:
//
// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
// llm_load_tensors: ggml ctx size = 0.27 MiB
// llm_load_tensors: offloading 32 repeating layers to GPU
// llm_load_tensors: offloading non-repeating layers to GPU
//
// with prefix = true, timestamps = true, the log output will look like this:
//
// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
//
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
// I - info (stdout, V = LOG_DEFAULT_INFO)
// W - warning (stderr, V = LOG_DEFAULT_WARN)
// E - error (stderr, V = LOG_DEFAULT_ERROR)
// O - output (stdout, V = LOG_DEFAULT_OUTPUT)
//
void common_log_set_file(struct common_log* log, const char* file); // not thread-safe
void common_log_set_colors(struct common_log* log, log_colors colors); // not thread-safe
void common_log_set_prefix(struct common_log* log, bool prefix); // whether to output prefix to each log
void common_log_set_timestamps(struct common_log* log, bool timestamps); // whether to output timestamps in the prefix
// helper macros for logging
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
//
// for example:
//
// LOG_DBG("this is a debug message: %d\n", expensive_function());
//
// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
//
#define LOG_TMPL(level, verbosity, ...) \
do { \
if ((verbosity) <= common_log_verbosity_thold) { \
common_log_add(common_log_main(), (level), __VA_ARGS__); \
} \
} while (0)
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
// --------------------------------
//
// Basic usage:
@@ -293,11 +412,11 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
// Main LOG macro.
// behaves like printf, and supports arguments the exact same way.
//
#if !defined(_MSC_VER) || defined(__clang__)
#define LOG(...) LOG_IMPL(__VA_ARGS__, "")
#else
#define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "")
#endif
//#if !defined(_MSC_VER) || defined(__clang__)
// #define LOG(...) LOG_IMPL(__VA_ARGS__, "")
//#else
// #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "")
//#endif
// Main TEE macro.
// does the same as LOG
@@ -721,3 +840,4 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
#define LOG_DUMP_CMDLINE(...) // dummy stub
#endif // LOG_DISABLE_LOGS