mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Refactor chat and server file (#1062)
* Add alternative log functions * chat: fix int overflow, prevent size calculation in float/double (#17357) * chat: fix int overflow, prevent size calculation in float/double * Update common/chat.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * common : move all common_chat_parse_* to chat-parser.cpp. (#17481) # Conflicts: # common/chat.cpp * server: split server.cpp code into server/common/task/queue/context * Fix compiler warning * Clean up code * common: use native MultiByteToWideChar * move server prompt to server task * Clean code * delete utils.hpp --------- Co-authored-by: firecoperana <firecoperana> Co-authored-by: Xuan-Son Nguyen <son@huggingface.co> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: DAN™ <dranger003@gmail.com>
This commit is contained in:
@@ -71,6 +71,8 @@ add_library(${TARGET} STATIC
|
||||
json-schema-to-grammar.cpp
|
||||
train.h
|
||||
train.cpp
|
||||
log.cpp
|
||||
log.h
|
||||
ngram-cache.h
|
||||
ngram-cache.cpp
|
||||
speculative.cpp
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -63,6 +63,9 @@ class common_chat_msg_parser {
|
||||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||
|
||||
// Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
|
||||
bool add_tool_call_short_form(const nlohmann::ordered_json& tool_call);
|
||||
|
||||
void finish();
|
||||
|
||||
bool consume_spaces();
|
||||
|
||||
804
common/chat.cpp
804
common/chat.cpp
@@ -627,6 +627,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
|
||||
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
|
||||
case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
||||
@@ -638,6 +639,10 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||
@@ -676,114 +681,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string& for
|
||||
throw std::runtime_error("Unknown reasoning format: " + format);
|
||||
}
|
||||
|
||||
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
|
||||
std::string arguments;
|
||||
if (builder.is_partial()) {
|
||||
arguments = (json {{"code", code + builder.healing_marker()}}).dump();
|
||||
auto idx = arguments.find(builder.healing_marker());
|
||||
if (idx != std::string::npos) {
|
||||
arguments.resize(idx);
|
||||
}
|
||||
} else {
|
||||
arguments = (json {{"code", code}}).dump();
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
* Aggregates the prefix, suffix and in-between text into the content.
|
||||
*/
|
||||
static void parse_json_tool_calls(
|
||||
common_chat_msg_parser & builder,
|
||||
const std::optional<common_regex> & block_open,
|
||||
const std::optional<common_regex> & function_regex_start_only,
|
||||
const std::optional<common_regex> & function_regex,
|
||||
const common_regex & close_regex,
|
||||
const std::optional<common_regex> & block_close,
|
||||
bool allow_raw_python = false,
|
||||
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name = nullptr) {
|
||||
|
||||
auto parse_tool_calls = [&]() {
|
||||
size_t from = std::string::npos;
|
||||
auto first = true;
|
||||
while (true) {
|
||||
auto start_pos = builder.pos();
|
||||
auto res = function_regex_start_only && first
|
||||
? builder.try_consume_regex(*function_regex_start_only)
|
||||
: function_regex
|
||||
? builder.try_find_regex(*function_regex, from)
|
||||
: std::nullopt;
|
||||
|
||||
if (res) {
|
||||
std::string name;
|
||||
if (get_function_name) {
|
||||
name = get_function_name(*res);
|
||||
} else {
|
||||
GGML_ASSERT(res->groups.size() == 2);
|
||||
name = builder.str(res->groups[1]);
|
||||
}
|
||||
first = false;
|
||||
if (name.empty()) {
|
||||
// get_function_name signalled us that we should skip this match and treat it as content.
|
||||
from = res->groups[0].begin + 1;
|
||||
continue;
|
||||
}
|
||||
from = std::string::npos;
|
||||
|
||||
auto maybe_raw_python = name == "python" && allow_raw_python;
|
||||
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
|
||||
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(close_regex);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (maybe_raw_python) {
|
||||
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
|
||||
if (!builder.add_tool_call(name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
return;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (block_close) {
|
||||
builder.consume_regex(*block_close);
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.add_content(builder.consume_rest());
|
||||
};
|
||||
if (block_open) {
|
||||
if (auto res = builder.try_find_regex(*block_open)) {
|
||||
parse_tool_calls();
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
} else {
|
||||
parse_tool_calls();
|
||||
}
|
||||
}
|
||||
|
||||
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
|
||||
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
|
||||
if (auto res = builder.try_find_regex(prefix)) {
|
||||
builder.move_back(rstrip_prefix);
|
||||
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
|
||||
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call array");
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
@@ -915,37 +812,6 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
static const std::vector<std::vector<std::string>> content_paths = {
|
||||
{"response"},
|
||||
};
|
||||
static const std::vector<std::vector<std::string>> args_paths = {
|
||||
{"tool_call", "arguments"},
|
||||
{"tool_calls", "arguments"},
|
||||
};
|
||||
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
|
||||
if (data.value.contains("tool_calls")) {
|
||||
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
}
|
||||
} else if (data.value.contains("tool_call")) {
|
||||
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else if (data.value.contains("response")) {
|
||||
const auto & response = data.value.at("response");
|
||||
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
|
||||
if (data.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete response");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -991,16 +857,6 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
@@ -1081,39 +937,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
|
||||
|
||||
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
|
||||
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
|
||||
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
|
||||
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
|
||||
|
||||
if (auto res = builder.try_find_regex(start_action_regex)) {
|
||||
// If we didn't extract thoughts, prelude includes them.
|
||||
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
|
||||
for (const auto & tool_call : tool_calls.value) {
|
||||
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
|
||||
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
|
||||
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
|
||||
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
if (tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(end_action_regex);
|
||||
} else if (auto res = builder.try_find_regex(start_response_regex)) {
|
||||
if (!builder.try_find_regex(end_response_regex)) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
throw common_chat_msg_partial_exception(end_response_regex.str());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
||||
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
||||
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
||||
@@ -1212,63 +1035,6 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
|
||||
});
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex function_regex(
|
||||
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
||||
static const common_regex close_regex("\\}\\s*");
|
||||
|
||||
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
|
||||
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
|
||||
|
||||
if (with_builtin_tools) {
|
||||
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
|
||||
if (auto res = builder.try_find_regex(builtin_call_regex)) {
|
||||
auto fun_res = builder.consume_regex(function_name_regex);
|
||||
auto function_name = builder.str(fun_res.groups[1]);
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
json args = json::object();
|
||||
while (true) {
|
||||
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
|
||||
auto arg_name = builder.str(arg_res->groups[1]);
|
||||
auto partial = builder.consume_json();
|
||||
args[arg_name] = partial.json;
|
||||
healing_marker.marker = partial.healing_marker.marker;
|
||||
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_literal(",")) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
builder.consume_literal(")");
|
||||
builder.consume_spaces();
|
||||
|
||||
auto arguments = args.dump();
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ std::nullopt,
|
||||
/* function_regex_start_only= */ function_regex,
|
||||
/* function_regex= */ std::nullopt,
|
||||
close_regex,
|
||||
std::nullopt);
|
||||
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -1408,88 +1174,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n");
|
||||
static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)");
|
||||
|
||||
static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>");
|
||||
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
|
||||
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
LOG("%s: not parse_tool_calls\n", __func__);
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
LOG("%s: parse_tool_calls\n", __func__);
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ tool_calls_begin,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
tool_calls_end);
|
||||
}
|
||||
|
||||
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
|
||||
// First try to parse using the standard reasoning parsing method
|
||||
LOG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
|
||||
|
||||
auto start_pos = builder.pos();
|
||||
auto found_end_think = builder.try_find_literal("</think>");
|
||||
builder.move_to(start_pos);
|
||||
|
||||
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
|
||||
LOG("%s: no end_think, not partial, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
|
||||
// If reasoning was parsed successfully, the remaining content is regular content
|
||||
LOG("%s: parsed reasoning, adding content\n", __func__);
|
||||
// </think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
} else {
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
|
||||
LOG("%s: reasoning_format none, adding content\n", __func__);
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
return;
|
||||
}
|
||||
// If no reasoning tags found, check if we should treat everything as reasoning
|
||||
if (builder.syntax().thinking_forced_open) {
|
||||
// If thinking is forced open but no tags found, treat everything as reasoning
|
||||
LOG("%s: thinking_forced_open, adding reasoning content\n", __func__);
|
||||
builder.add_reasoning_content(builder.consume_rest());
|
||||
} else {
|
||||
LOG("%s: no thinking_forced_open, adding content\n", __func__);
|
||||
// <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|>
|
||||
common_chat_parse_deepseek_v3_1_content(builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -1532,20 +1216,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</invoke>",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -1578,23 +1248,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -1639,25 +1292,6 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\":";
|
||||
form.val_end = ",";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.allow_toolcall_in_think = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -1693,25 +1327,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -1744,24 +1359,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
auto prompt = apply(tmpl, inputs);
|
||||
@@ -1892,93 +1489,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
|
||||
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
|
||||
|
||||
static const common_regex start_regex("<\\|start\\|>assistant");
|
||||
static const common_regex analysis_regex("<\\|channel\\|>analysis");
|
||||
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
|
||||
static const common_regex preamble_regex("<\\|channel\\|>commentary");
|
||||
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
|
||||
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
|
||||
|
||||
auto consume_end = [&](bool include_end = false) {
|
||||
if (auto res = builder.try_find_literal("<|end|>")) {
|
||||
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
|
||||
}
|
||||
return builder.consume_rest();
|
||||
};
|
||||
|
||||
auto handle_tool_call = [&](const std::string & name) {
|
||||
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (builder.syntax().parse_tool_calls) {
|
||||
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else if (args->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
|
||||
auto match = regex.search(input, 0, true);
|
||||
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
|
||||
return match;
|
||||
}
|
||||
return std::nullopt;
|
||||
};
|
||||
|
||||
do {
|
||||
auto header_start_pos = builder.pos();
|
||||
auto content_start = builder.try_find_literal("<|message|>");
|
||||
if (!content_start) {
|
||||
throw common_chat_msg_partial_exception("incomplete header");
|
||||
}
|
||||
|
||||
auto header = content_start->prelude;
|
||||
|
||||
if (auto match = regex_match(tool_call1_regex, header)) {
|
||||
auto group = match->groups[1];
|
||||
auto name = header.substr(group.begin, group.end - group.begin);
|
||||
handle_tool_call(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto match = regex_match(tool_call2_regex, header)) {
|
||||
auto group = match->groups[2];
|
||||
auto name = header.substr(group.begin, group.end - group.begin);
|
||||
handle_tool_call(name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (regex_match(analysis_regex, header)) {
|
||||
builder.move_to(header_start_pos);
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(consume_end(true));
|
||||
} else {
|
||||
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
|
||||
builder.add_content(consume_end());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Possibly a malformed message, attempt to recover by rolling
|
||||
// back to pick up the next <|start|>
|
||||
LOG("%s: unknown header from message: %s\n", __func__, header.c_str());
|
||||
builder.move_to(header_start_pos);
|
||||
} while (builder.try_find_regex(start_regex, std::string::npos, false));
|
||||
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!remaining.empty()) {
|
||||
LOG("%s: content after last message: %s\n", __func__, remaining.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -2059,21 +1569,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "<tool_call>",
|
||||
/* form.tool_sep = */ "",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>",
|
||||
/* form.val_end = */ "</arg_value>",
|
||||
/* form.tool_end = */ "</tool_call>",
|
||||
/* form.scope_end = */ "",
|
||||
/* form.key_val_sep2 = */ "<arg_value>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
@@ -2119,14 +1614,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
}
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
static const common_regex prefix(regex_escape(" functools["));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
@@ -2177,34 +1664,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||
}
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
|
||||
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
|
||||
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
|
||||
static const common_regex close_regex(R"(\s*)");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
std::nullopt,
|
||||
function_regex_start_only,
|
||||
function_regex,
|
||||
close_regex,
|
||||
std::nullopt,
|
||||
/* allow_raw_python= */ true,
|
||||
/* get_function_name= */ [&](const auto & res) -> std::string {
|
||||
auto at_start = res.groups[0].begin == 0;
|
||||
auto name = builder.str(res.groups[1]);
|
||||
if (!name.empty() && name.back() == '{') {
|
||||
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
|
||||
builder.move_back(1);
|
||||
}
|
||||
auto idx = name.find_last_not_of("\n{");
|
||||
name = name.substr(0, idx + 1);
|
||||
if (at_start && name == "all") {
|
||||
return "";
|
||||
}
|
||||
return name;
|
||||
});
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
@@ -2264,31 +1723,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
// TODO: if (has_raw_python)
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
||||
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
|
||||
|
||||
static const common_regex function_regex(R"(<function=(\w+)>)");
|
||||
static const common_regex close_regex(R"(</function>)");
|
||||
|
||||
parse_json_tool_calls(
|
||||
builder,
|
||||
/* block_open= */ std::nullopt,
|
||||
/* function_regex_start_only= */ std::nullopt,
|
||||
function_regex,
|
||||
close_regex,
|
||||
std::nullopt);
|
||||
|
||||
if (auto res = builder.try_find_regex(python_tag_regex)) {
|
||||
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
|
||||
builder.add_tool_call("python", "", arguments);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -2405,83 +1839,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
||||
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex open_regex(
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(" // match 2 (open_tag)
|
||||
"<tool_call>"
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
|
||||
")"
|
||||
"|<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
|
||||
);
|
||||
|
||||
while (auto res = builder.try_find_regex(open_regex)) {
|
||||
const auto & block_start = res->groups[1];
|
||||
std::string block_end = block_start.empty() ? "" : "```";
|
||||
|
||||
const auto & open_tag = res->groups[2];
|
||||
std::string close_tag;
|
||||
|
||||
if (!res->groups[3].empty()) {
|
||||
builder.move_to(res->groups[3].begin);
|
||||
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
|
||||
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
|
||||
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.consume_literal(close_tag);
|
||||
builder.consume_spaces();
|
||||
if (!block_end.empty()) {
|
||||
builder.consume_literal(block_end);
|
||||
builder.consume_spaces();
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("failed to parse tool call");
|
||||
}
|
||||
} else {
|
||||
auto function_name = builder.str(res->groups[4]);
|
||||
if (function_name.empty()) {
|
||||
function_name = builder.str(res->groups[5]);
|
||||
}
|
||||
GGML_ASSERT(!function_name.empty());
|
||||
|
||||
close_tag = "</function>";
|
||||
|
||||
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
|
||||
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_spaces();
|
||||
builder.consume_literal(close_tag);
|
||||
builder.consume_spaces();
|
||||
if (!block_end.empty()) {
|
||||
builder.consume_literal(block_end);
|
||||
builder.consume_spaces();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
@@ -2564,53 +1921,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
static const common_regex start_think_regex(regex_escape("<think>"));
|
||||
static const common_regex end_think_regex(regex_escape("</think>"));
|
||||
// Granite models output partial tokens such as "<" and "<think".
|
||||
// By leveraging try_consume_regex()/try_find_regex() throwing
|
||||
// common_chat_msg_partial_exception for these partial tokens,
|
||||
// processing is interrupted and the tokens are not passed to add_content().
|
||||
if (auto res = builder.try_consume_regex(start_think_regex)) {
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
builder.try_find_regex(end_think_regex, std::string::npos, false);
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
}
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
// Parse response tags
|
||||
static const common_regex start_response_regex(regex_escape("<response>"));
|
||||
static const common_regex end_response_regex(regex_escape("</response>"));
|
||||
// Granite models output partial tokens such as "<" and "<response".
|
||||
// Same hack as reasoning parsing.
|
||||
if (builder.try_consume_regex(start_response_regex)) {
|
||||
builder.try_find_regex(end_response_regex);
|
||||
}
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
|
||||
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -2802,7 +2112,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
int alloc_size = 0;
|
||||
size_t alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
std::vector<std::string> contents;
|
||||
|
||||
@@ -2824,7 +2134,8 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
const auto & msg = inputs.messages[i];
|
||||
const auto & content = contents[i];
|
||||
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
||||
size_t msg_size = msg.role.size() + content.size();
|
||||
alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
|
||||
}
|
||||
|
||||
std::vector<char> buf(alloc_size);
|
||||
@@ -2846,6 +2157,11 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
res = llama_chat_apply_template(nullptr, src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
// for safety, we check the result again
|
||||
if (res < 0 || (size_t) res > buf.size()) {
|
||||
throw std::runtime_error("failed to apply chat template, try using --jinja");
|
||||
}
|
||||
|
||||
common_chat_params params;
|
||||
params.prompt = std::string(buf.data(), res);
|
||||
if (!inputs.json_schema.empty()) {
|
||||
@@ -2865,97 +2181,3 @@ common_chat_params common_chat_templates_apply(
|
||||
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||
}
|
||||
|
||||
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
LOG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
|
||||
|
||||
switch (builder.syntax().format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
||||
common_chat_parse_content_only(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GENERIC:
|
||||
common_chat_parse_generic(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
||||
common_chat_parse_mistral_nemo(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
||||
common_chat_parse_llama_3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
||||
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
common_chat_parse_deepseek_r1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
|
||||
common_chat_parse_deepseek_v3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
||||
common_chat_parse_functionary_v3_2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
||||
common_chat_parse_functionary_v3_1_llama_3_1(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
||||
common_chat_parse_hermes_2_pro(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
common_chat_parse_firefunction_v2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
common_chat_parse_command_r7b(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GRANITE:
|
||||
common_chat_parse_granite(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS:
|
||||
common_chat_parse_gpt_oss(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||
common_chat_parse_minimax_m2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||
common_chat_parse_glm_4_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||
common_chat_parse_xiaomi_mimo(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
builder.finish();
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg_parser builder(input, is_partial, syntax);
|
||||
try {
|
||||
common_chat_parse(builder);
|
||||
} catch (const common_chat_msg_partial_exception & ex) {
|
||||
LOG("Partial parse: %s\n", ex.what());
|
||||
if (!is_partial) {
|
||||
builder.clear_tools();
|
||||
builder.move_to(0);
|
||||
common_chat_parse_content_only(builder);
|
||||
}
|
||||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
@@ -101,6 +101,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
COMMON_CHAT_FORMAT_GENERIC,
|
||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||
COMMON_CHAT_FORMAT_MAGISTRAL,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
@@ -112,6 +113,10 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_GRANITE,
|
||||
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
|
||||
@@ -2726,11 +2726,29 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
static std::wstring utf8_to_wstring(const std::string& str) {
|
||||
if (str.empty()) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
|
||||
|
||||
if (size <= 0) {
|
||||
return std::wstring();
|
||||
}
|
||||
|
||||
std::wstring wstr(size, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
|
||||
|
||||
return wstr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
std::wstring wpath = utf8_to_wstring(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
@@ -3586,175 +3604,6 @@ bool llama_should_add_bos_token(const llama_model * model) {
|
||||
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
|
||||
}
|
||||
|
||||
//
|
||||
// Chat template utils
|
||||
//
|
||||
//
|
||||
//bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
|
||||
// if (use_jinja) {
|
||||
// try {
|
||||
// auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
// common_chat_inputs inputs;
|
||||
// inputs.messages = json::array({ {
|
||||
// {"role", "user"},
|
||||
// {"content", "test"},
|
||||
// } });
|
||||
// common_chat_params_init(chat_template, inputs);
|
||||
// return true;
|
||||
// }
|
||||
// catch (const std::exception& e) {
|
||||
// fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
|
||||
// return false;
|
||||
// }
|
||||
// }
|
||||
// llama_chat_message chat[] = { {"user", "test"} };
|
||||
// const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
// return res >= 0;
|
||||
//}
|
||||
|
||||
//std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
// const common_chat_template& tmpl,
|
||||
// const std::vector<common_chat_msg> & msgs,
|
||||
// bool add_ass,
|
||||
// bool use_jinja) {
|
||||
// if (use_jinja) {
|
||||
// auto messages = json::array();
|
||||
// for (const auto& msg : msgs) {
|
||||
// messages.push_back({ {"role", msg.role}, {"content", msg.content} });
|
||||
// }
|
||||
// common_chat_inputs inputs;
|
||||
// inputs.messages = messages;
|
||||
// inputs.add_generation_prompt = add_ass;
|
||||
// return common_chat_params_init(tmpl, inputs).prompt;
|
||||
// }
|
||||
// int alloc_size = 0;
|
||||
// std::vector<llama_chat_message> chat;
|
||||
// for (auto & msg : msgs) {
|
||||
// chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
// alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
// }
|
||||
//
|
||||
// std::vector<char> buf(alloc_size);
|
||||
//
|
||||
// // run the first time to get the total output length
|
||||
// int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// // error: chat template is not supported
|
||||
// if (res < 0) {
|
||||
// // if the custom "tmpl" is not supported, we throw an error
|
||||
// // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
// throw std::runtime_error("this custom template is not supported");
|
||||
// }
|
||||
//
|
||||
// // if it turns out that our buffer is too small, we resize it
|
||||
// if ((size_t)res > buf.size()) {
|
||||
// buf.resize(res);
|
||||
// res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// }
|
||||
//
|
||||
// std::string formatted_chat(buf.data(), res);
|
||||
// return formatted_chat;
|
||||
//}
|
||||
////
|
||||
//std::string llama_chat_format_single(const struct llama_model * model,
|
||||
// const common_chat_template& tmpl,
|
||||
// const std::vector<common_chat_msg> & past_msg,
|
||||
// const common_chat_msg & new_msg,
|
||||
// bool add_ass,
|
||||
// bool use_jinja) {
|
||||
// std::ostringstream ss;
|
||||
// auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
|
||||
// std::vector<common_chat_msg> chat_new(past_msg);
|
||||
// // if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
// if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
// ss << "\n";
|
||||
// };
|
||||
// // format chat with new_msg
|
||||
// chat_new.push_back(new_msg);
|
||||
// auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
|
||||
// // get the diff part
|
||||
// ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
// return ss.str();
|
||||
//}
|
||||
|
||||
//std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
|
||||
// std::vector<common_chat_msg> msgs = {
|
||||
// {"system", "You are a helpful assistant", {}},
|
||||
// {"user", "Hello", {}},
|
||||
// {"assistant", "Hi there", {}},
|
||||
// {"user", "How are you?", {}},
|
||||
// };
|
||||
// return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
|
||||
//}
|
||||
//
|
||||
//#define CHATML_TEMPLATE_SRC \
|
||||
// "{%- for message in messages -%}\n" \
|
||||
// " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
||||
// "{%- endfor -%}\n" \
|
||||
// "{%- if add_generation_prompt -%}\n" \
|
||||
// " {{- '<|im_start|>assistant\n' -}}\n" \
|
||||
// "{%- endif -%}"
|
||||
//
|
||||
//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
|
||||
//{
|
||||
// std::string default_template_src;
|
||||
// std::string template_tool_use_src;
|
||||
// bool has_explicit_template = !chat_template_override.empty();
|
||||
// if (chat_template_override.empty()) {
|
||||
// auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
// if (str) {
|
||||
// default_template_src = str;
|
||||
// has_explicit_template = true;
|
||||
// }
|
||||
// str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
// if (str) {
|
||||
// template_tool_use_src = str;
|
||||
// has_explicit_template = true;
|
||||
// }
|
||||
// }
|
||||
// else {
|
||||
// default_template_src = chat_template_override;
|
||||
// }
|
||||
// if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
// if (!template_tool_use_src.empty()) {
|
||||
// default_template_src = template_tool_use_src;
|
||||
// }
|
||||
// else {
|
||||
// default_template_src = CHATML_TEMPLATE_SRC;
|
||||
// }
|
||||
// }
|
||||
// auto vocab = llama_model_get_vocab(model);
|
||||
// const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
|
||||
// if (token == LLAMA_TOKEN_NULL) {
|
||||
// if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
// || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
// fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
// }
|
||||
// return std::string();
|
||||
// }
|
||||
// else {
|
||||
// return llama_token_to_piece(model, token, true);
|
||||
// }
|
||||
// };
|
||||
// auto token_bos = get_token(llama_token_bos_impl(*vocab), "BOS", "bos_token");
|
||||
// auto token_eos = get_token(llama_token_eos_impl(*vocab), "EOS", "eos_token");
|
||||
// try {
|
||||
// return {
|
||||
// has_explicit_template,
|
||||
// std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
// template_tool_use_src.empty()
|
||||
// ? nullptr
|
||||
// : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
|
||||
// };
|
||||
// }
|
||||
// catch (const std::exception& e) {
|
||||
// LOG("%s: failed to parse chat template: %s\n", __func__, e.what());
|
||||
// return {
|
||||
// has_explicit_template,
|
||||
// std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
|
||||
// nullptr,
|
||||
// };
|
||||
// }
|
||||
//}
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
|
||||
464
common/log.cpp
Normal file
464
common/log.cpp
Normal file
@@ -0,0 +1,464 @@
#include "log.h"

#include <chrono>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <sstream>
#include <thread>
#include <vector>

#if defined(_WIN32)
#    include <io.h>
#    include <windows.h>
#    define isatty _isatty
#    define fileno _fileno
#else
#    include <unistd.h>
#endif // defined(_WIN32)

int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;

void common_log_set_verbosity_thold(int verbosity) {
    common_log_verbosity_thold = verbosity;
}

// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
    // Check NO_COLOR environment variable (https://no-color.org/)
    if (const char * no_color = std::getenv("NO_COLOR")) {
        if (no_color[0] != '\0') {
            return false;
        }
    }

    // Check TERM environment variable
    if (const char * term = std::getenv("TERM")) {
        if (std::strcmp(term, "dumb") == 0) {
            return false;
        }
    }

    // Check if stdout and stderr are connected to a terminal
    // We check both because log messages can go to either
    bool stdout_is_tty = isatty(fileno(stdout));
    bool stderr_is_tty = isatty(fileno(stderr));

    return stdout_is_tty || stderr_is_tty;
}

static int64_t t_us() {
    return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}

// colors
enum common_log_col : int {
    COMMON_LOG_COL_DEFAULT = 0,
    COMMON_LOG_COL_BOLD,
    COMMON_LOG_COL_RED,
    COMMON_LOG_COL_GREEN,
    COMMON_LOG_COL_YELLOW,
    COMMON_LOG_COL_BLUE,
    COMMON_LOG_COL_MAGENTA,
    COMMON_LOG_COL_CYAN,
    COMMON_LOG_COL_WHITE,
};

// disable colors by default
static std::vector<const char *> g_col = {
    "",
    "",
    "",
    "",
    "",
    "",
    "",
    "",
    "",
};

struct common_log_entry {
    enum ggml_log_level level;

    bool prefix;

    int64_t timestamp;

    std::vector<char> msg;

    // signals the worker thread to stop
    bool is_end;

    void print(FILE * file = nullptr) const {
        FILE * fcur = file;
        if (!fcur) {
            // stderr displays DBG messages only when their verbosity level is not higher than the threshold
            // these messages will still be logged to a file
            if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
                return;
            }

            fcur = stdout;

            if (level != GGML_LOG_LEVEL_NONE) {
                fcur = stderr;
            }
        }

        if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
            if (timestamp) {
                // [M.s.ms.us]
                fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
                        g_col[COMMON_LOG_COL_BLUE],
                        (int) (timestamp / 1000000 / 60),
                        (int) (timestamp / 1000000 % 60),
                        (int) (timestamp / 1000 % 1000),
                        (int) (timestamp % 1000),
                        g_col[COMMON_LOG_COL_DEFAULT]);
            }

            switch (level) {
                case GGML_LOG_LEVEL_INFO:  fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN],   g_col[COMMON_LOG_COL_DEFAULT]); break;
                case GGML_LOG_LEVEL_WARN:  fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], ""); break;
                case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED],     ""); break;
                case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW],  ""); break;
                default:
                    break;
            }
        }

        fprintf(fcur, "%s", msg.data());

        if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
            fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
        }

        fflush(fcur);
    }
};

struct common_log {
    // default capacity - will be expanded if needed
    common_log() : common_log(256) {}

    common_log(size_t capacity) {
        file = nullptr;
        prefix = false;
        timestamps = false;
        running = false;
        t_start = t_us();

        // initial message size - will be expanded if longer messages arrive
        entries.resize(capacity);
        for (auto & entry : entries) {
            entry.msg.resize(256);
        }

        head = 0;
        tail = 0;

        resume();
    }

    ~common_log() {
        pause();
        if (file) {
            fclose(file);
        }
    }

private:
    std::mutex mtx;
    std::thread thrd;
    std::condition_variable cv;

    FILE * file;

    bool prefix;
    bool timestamps;
    bool running;

    int64_t t_start;

    // ring buffer of entries
    std::vector<common_log_entry> entries;
    size_t head;
    size_t tail;

    // worker thread copies into this
    common_log_entry cur;

public:
    void add(enum ggml_log_level level, const char * fmt, va_list args) {
        std::lock_guard<std::mutex> lock(mtx);

        if (!running) {
            // discard messages while the worker thread is paused
            return;
        }

        auto & entry = entries[tail];

        {
            // cannot use args twice, so make a copy in case we need to expand the buffer
            va_list args_copy;
            va_copy(args_copy, args);

#if 1
            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
            if (n >= entry.msg.size()) {
                entry.msg.resize(n + 1);
                vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
            }
#else
            // hack for bolding arguments

            std::stringstream ss;
            for (int i = 0; fmt[i] != 0; i++) {
                if (fmt[i] == '%') {
                    ss << LOG_COL_BOLD;
                    while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
                    ss << LOG_COL_DEFAULT;
                    if (fmt[i] == 0) break;
                }
                ss << fmt[i];
            }
            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
            if (n >= entry.msg.size()) {
                entry.msg.resize(n + 1);
                vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
            }
#endif

            va_end(args_copy);
        }

        entry.level = level;
        entry.prefix = prefix;
        entry.timestamp = 0;
        if (timestamps) {
            entry.timestamp = t_us() - t_start;
        }
        entry.is_end = false;

        tail = (tail + 1) % entries.size();
        if (tail == head) {
            // expand the buffer
            std::vector<common_log_entry> new_entries(2*entries.size());

            size_t new_tail = 0;

            do {
                new_entries[new_tail] = std::move(entries[head]);

                head     = (head     + 1) % entries.size();
                new_tail = (new_tail + 1);
            } while (head != tail);

            head = 0;
            tail = new_tail;

            for (size_t i = tail; i < new_entries.size(); i++) {
                new_entries[i].msg.resize(256);
            }

            entries = std::move(new_entries);
        }

        cv.notify_one();
    }

    void resume() {
        std::lock_guard<std::mutex> lock(mtx);

        if (running) {
            return;
        }

        running = true;

        thrd = std::thread([this]() {
            while (true) {
                {
                    std::unique_lock<std::mutex> lock(mtx);
                    cv.wait(lock, [this]() { return head != tail; });

                    cur = entries[head];

                    head = (head + 1) % entries.size();
                }

                if (cur.is_end) {
                    break;
                }

                cur.print(); // stdout and stderr

                if (file) {
                    cur.print(file);
                }
            }
        });
    }

    void pause() {
        {
            std::lock_guard<std::mutex> lock(mtx);

            if (!running) {
                return;
            }

            running = false;

            // push an entry to signal the worker thread to stop
            {
                auto & entry = entries[tail];
                entry.is_end = true;

                tail = (tail + 1) % entries.size();
            }

            cv.notify_one();
        }

        thrd.join();
    }

    void set_file(const char * path) {
        pause();

        if (file) {
            fclose(file);
        }

        if (path) {
            file = fopen(path, "w");
        } else {
            file = nullptr;
        }

        resume();
    }

    void set_colors(bool colors) {
        pause();

        if (colors) {
            g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
            g_col[COMMON_LOG_COL_BOLD]    = LOG_COL_BOLD;
            g_col[COMMON_LOG_COL_RED]     = LOG_COL_RED;
            g_col[COMMON_LOG_COL_GREEN]   = LOG_COL_GREEN;
            g_col[COMMON_LOG_COL_YELLOW]  = LOG_COL_YELLOW;
            g_col[COMMON_LOG_COL_BLUE]    = LOG_COL_BLUE;
            g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
            g_col[COMMON_LOG_COL_CYAN]    = LOG_COL_CYAN;
            g_col[COMMON_LOG_COL_WHITE]   = LOG_COL_WHITE;
        } else {
            for (size_t i = 0; i < g_col.size(); i++) {
                g_col[i] = "";
            }
        }

        resume();
    }

    void set_prefix(bool prefix) {
        std::lock_guard<std::mutex> lock(mtx);

        this->prefix = prefix;
    }

    void set_timestamps(bool timestamps) {
        std::lock_guard<std::mutex> lock(mtx);

        this->timestamps = timestamps;
    }
};

//
// public API
//

struct common_log * common_log_init() {
    return new common_log;
}

struct common_log * common_log_main() {
    static struct common_log log;
    static std::once_flag init_flag;
    std::call_once(init_flag, [&]() {
        // Set default to auto-detect colors
        log.set_colors(common_log_should_use_colors_auto());
    });

    return &log;
}

void common_log_pause(struct common_log * log) {
    log->pause();
}

void common_log_resume(struct common_log * log) {
    log->resume();
}

void common_log_free(struct common_log * log) {
    delete log;
}

void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    log->add(level, fmt, args);
    va_end(args);
}

void common_log_set_file(struct common_log * log, const char * file) {
    log->set_file(file);
}

void common_log_set_colors(struct common_log * log, log_colors colors) {
    if (colors == LOG_COLORS_AUTO) {
        log->set_colors(common_log_should_use_colors_auto());
        return;
    }

    if (colors == LOG_COLORS_DISABLED) {
        log->set_colors(false);
        return;
    }

    GGML_ASSERT(colors == LOG_COLORS_ENABLED);
    log->set_colors(true);
}

void common_log_set_prefix(struct common_log * log, bool prefix) {
    log->set_prefix(prefix);
}

void common_log_set_timestamps(struct common_log * log, bool timestamps) {
    log->set_timestamps(timestamps);
}

static int common_get_verbosity(enum ggml_log_level level) {
    switch (level) {
        case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
        case GGML_LOG_LEVEL_INFO:  return LOG_LEVEL_INFO;
        case GGML_LOG_LEVEL_WARN:  return LOG_LEVEL_WARN;
        case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
        case GGML_LOG_LEVEL_CONT:  return LOG_LEVEL_INFO; // same as INFO
        case GGML_LOG_LEVEL_NONE:
        default:
            return LOG_LEVEL_OUTPUT;
    }
}

void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
    auto verbosity = common_get_verbosity(level);
    if (verbosity <= common_log_verbosity_thold) {
        common_log_add(common_log_main(), level, "%s", text);
    }
}
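A minimal usage sketch of the logging API added above (not part of the commit; the file name "my-tool.cpp" is invented for illustration):

// my-tool.cpp - hypothetical consumer of common/log.h
#include "log.h"

int main() {
    common_log * log = common_log_main();            // lazily-initialized singleton, colors auto-detected
    common_log_set_prefix(log, true);                // print the level letter (I/W/E/D) before each message
    common_log_set_timestamps(log, true);            // print [M.s.ms.us] timestamps in the prefix
    common_log_set_verbosity_thold(LOG_LEVEL_INFO);  // drop DBG messages before they are formatted

    LOG_INF("starting up, %d slots\n", 4);                // routed through the background worker thread
    LOG_WRN("low memory: %zu MiB free\n", (size_t) 512);  // color is reset at the end of the line
    return 0;
}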
common/log.h (130 lines changed)
@@ -1,4 +1,5 @@
#pragma once
#include "ggml.h" // for ggml_log_level
#include <chrono>
#include <cstring>
#include <sstream>
@@ -8,6 +9,124 @@
#include <algorithm>
#include <cinttypes>

#define LOG_CLR_TO_EOL  "\033[K\r"
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD    "\033[1m"
#define LOG_COL_RED     "\033[31m"
#define LOG_COL_GREEN   "\033[32m"
#define LOG_COL_YELLOW  "\033[33m"
#define LOG_COL_BLUE    "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN    "\033[36m"
#define LOG_COL_WHITE   "\033[37m"

#ifndef __GNUC__
#    define LOG_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__) && !defined(__clang__)
#    define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
#    define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif

#define LOG_LEVEL_DEBUG  4
#define LOG_LEVEL_INFO   3
#define LOG_LEVEL_WARN   2
#define LOG_LEVEL_ERROR  1
#define LOG_LEVEL_OUTPUT 0 // output data from tools

#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO

enum log_colors {
    LOG_COLORS_AUTO     = -1,
    LOG_COLORS_DISABLED =  0,
    LOG_COLORS_ENABLED  =  1,
};

// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
// set via common_log_set_verbosity()
extern int common_log_verbosity_thold;

void common_log_set_verbosity_thold(int verbosity); // not thread-safe

void common_log_default_callback(enum ggml_log_level level, const char* text, void* user_data);

// the common_log uses an internal worker thread to print/write log messages
// when the worker thread is paused, incoming log messages are discarded
struct common_log;

struct common_log* common_log_init();
struct common_log* common_log_main();   // singleton, automatically destroys itself on exit
void common_log_pause (struct common_log* log); // pause  the worker thread, not thread-safe
void common_log_resume(struct common_log* log); // resume the worker thread, not thread-safe
void common_log_free  (struct common_log* log);

LOG_ATTRIBUTE_FORMAT(3, 4)
void common_log_add(struct common_log* log, enum ggml_log_level level, const char* fmt, ...);

// defaults: file = NULL, colors = false, prefix = false, timestamps = false
//
// regular log output:
//
//   ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
//   llm_load_tensors: ggml ctx size = 0.27 MiB
//   llm_load_tensors: offloading 32 repeating layers to GPU
//   llm_load_tensors: offloading non-repeating layers to GPU
//
// with prefix = true, timestamps = true, the log output will look like this:
//
//   0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
//   0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB
//   0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
//   0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
//
// D - debug   (stderr, V = LOG_DEFAULT_DEBUG)
// I - info    (stdout, V = LOG_DEFAULT_INFO)
// W - warning (stderr, V = LOG_DEFAULT_WARN)
// E - error   (stderr, V = LOG_DEFAULT_ERROR)
// O - output  (stdout, V = LOG_DEFAULT_OUTPUT)
//

void common_log_set_file      (struct common_log* log, const char* file);      // not thread-safe
void common_log_set_colors    (struct common_log* log, log_colors colors);     // not thread-safe
void common_log_set_prefix    (struct common_log* log, bool prefix);           // whether to output prefix to each log
void common_log_set_timestamps(struct common_log* log, bool timestamps);       // whether to output timestamps in the prefix

// helper macros for logging
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
//
// for example:
//
//   LOG_DBG("this is a debug message: %d\n", expensive_function());
//
// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
//

#define LOG_TMPL(level, verbosity, ...) \
    do { \
        if ((verbosity) <= common_log_verbosity_thold) { \
            common_log_add(common_log_main(), (level), __VA_ARGS__); \
        } \
    } while (0)

#define LOG(...)             LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)

#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  LOG_LEVEL_INFO,  __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  LOG_LEVEL_WARN,  __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  LOG_LEVEL_INFO,  __VA_ARGS__) // same as INFO

#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  verbosity, __VA_ARGS__)
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  verbosity, __VA_ARGS__)
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  verbosity, __VA_ARGS__)

// --------------------------------
//
// Basic usage:
@@ -293,11 +412,11 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
// Main LOG macro.
// behaves like printf, and supports arguments the exact same way.
//
#if !defined(_MSC_VER) || defined(__clang__)
    #define LOG(...) LOG_IMPL(__VA_ARGS__, "")
#else
    #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "")
#endif
//#if !defined(_MSC_VER) || defined(__clang__)
//    #define LOG(...) LOG_IMPL(__VA_ARGS__, "")
//#else
//    #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "")
//#endif

// Main TEE macro.
// does the same as LOG
@@ -721,3 +840,4 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
#define LOG_DUMP_CMDLINE(...) // dummy stub

#endif // LOG_DISABLE_LOGS
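The LOG_* macros above check the verbosity threshold before evaluating their arguments; a small sketch of what that buys (expensive_stats() is a hypothetical helper, not from the commit):

// Arguments are only evaluated when the message will actually be emitted:
//   LOG_DBG("stats: %s\n", expensive_stats().c_str());
// expands roughly to
//   do {
//       if (LOG_LEVEL_DEBUG <= common_log_verbosity_thold) {
//           common_log_add(common_log_main(), GGML_LOG_LEVEL_DEBUG, "stats: %s\n", expensive_stats().c_str());
//       }
//   } while (0)
// so expensive_stats() is never called at the default LOG_DEFAULT_LLAMA (= LOG_LEVEL_INFO) threshold.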
@@ -38,10 +38,29 @@ namespace fs = std::filesystem;

// NOTE: this is copied from common.cpp to avoid linking with libcommon
// returns true if successful, false otherwise

#ifdef _WIN32
static std::wstring utf8_to_wstring(const std::string& str) {
    if (str.empty()) {
        return std::wstring();
    }

    int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);

    if (size <= 0) {
        return std::wstring();
    }

    std::wstring wstr(size, 0);
    MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);

    return wstr;
}
#endif

static bool fs_create_directory_with_parents(const std::string& path) {
#ifdef _WIN32
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    std::wstring wpath = converter.from_bytes(path);
    std::wstring wpath = utf8_to_wstring(path);

    // if the path already exists, check whether it's a directory
    const DWORD attributes = GetFileAttributesW(wpath.c_str());
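For context, a sketch of calling the new helper directly on Windows (illustrative only; the path is made up). The point of the change is that std::wstring_convert/std::codecvt_utf8 is deprecated since C++17, while MultiByteToWideChar is the native Win32 conversion:

#ifdef _WIN32
    // convert a UTF-8 path to UTF-16 before handing it to a wide Win32 API
    std::wstring wpath = utf8_to_wstring("C:\\models\\llama");
    CreateDirectoryW(wpath.c_str(), NULL);
#endif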
@@ -12,8 +12,15 @@ endif()

set(TARGET_SRCS
    server.cpp
    utils.hpp
    httplib.h
    server-task.cpp
    server-task.h
    server-queue.cpp
    server-queue.h
    server-common.cpp
    server-common.h
    server-context.cpp
    server-context.h
)
set(PUBLIC_ASSETS
    index.html.gz
@@ -1,99 +1,9 @@
#pragma once

#include "llama.h"
#include <src/llama-impl.h>
#include "common.h"

// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include "base64.hpp"
#include "mtmd.h"
#include "mtmd-helper.h"
#include "chat.h"
#include <string>
#include <vector>
#include <sstream>
#include <random>
#include <set>

// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
// increase backlog size to avoid connection resets for >> 1 slots
#define CPPHTTPLIB_LISTEN_BACKLOG 512
// increase max URI length to handle longer prompts in query string
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
// disable Nagle's algorithm
#define CPPHTTPLIB_TCP_NODELAY true
#include "httplib.h"

#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

using json = nlohmann::ordered_json;

// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
    ERROR_TYPE_INVALID_REQUEST,
    ERROR_TYPE_AUTHENTICATION,
    ERROR_TYPE_SERVER,
    ERROR_TYPE_NOT_FOUND,
    ERROR_TYPE_PERMISSION,
    ERROR_TYPE_UNAVAILABLE,   // custom error
    ERROR_TYPE_NOT_SUPPORTED, // custom error
};

extern bool server_verbose;
extern bool server_log_json;

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif

#if SERVER_VERBOSE != 1
#define LOG_VERBOSE(MSG, ...)
#else
#define LOG_VERBOSE(MSG, ...) \
    do \
    { \
        if (server_verbose) \
        { \
            server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \
        } \
    } while (0)
#endif

#define LOG_ERROR(  MSG, ...) server_log("ERR",  __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO(   MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
#include "server-common.h"

using raw_buffer = std::vector<uint8_t>;

static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra);

template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    // Fallback null to default value
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key);
        } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const& err) {
            std::stringstream ss;
            ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value: " << err.what();
            LOG_WARNING(ss.str().c_str(), body);
            return default_value;
        }
    } else {
        return default_value;
    }
}

// thin wrapper around common_grammar_trigger with (de)serialization functions
struct server_grammar_trigger {
    common_grammar_trigger value;

    server_grammar_trigger() = default;
    server_grammar_trigger(const common_grammar_trigger& value) : value(value) {}
    server_grammar_trigger(const json& in) {
server_grammar_trigger::server_grammar_trigger(const json& in) {
    value.type = (common_grammar_trigger_type)in.at("type").get<int>();
    value.value = in.at("value").get<std::string>();
    if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
@@ -101,7 +11,7 @@ struct server_grammar_trigger {
    }
}

    json to_json() const {
json server_grammar_trigger::to_json() const {
    json out{
        {"type", (int)value.type},
        {"value", value.value},
@@ -111,9 +21,9 @@ struct server_grammar_trigger {
    }
    return out;
}
};

static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) {

void server_log(const char* level, const char* function, int line, const char* message, const json& extra) {
    std::stringstream ss_tid;
    ss_tid << std::this_thread::get_id();
    json log = json{
@@ -134,7 +44,8 @@ static inline void server_log(const char * level, const char * function, int lin
    }

    printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
    } else {
    }
    else {
        char buf[1024];
        snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);

@@ -159,20 +70,12 @@ static inline void server_log(const char * level, const char * function, int lin
// chat template utils
//

//
// base64 utils (TODO: move to common in the future)
//

static const std::string base64_chars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";

static inline bool is_base64(uint8_t c) {
bool is_base64(uint8_t c) {
    return (isalnum(c) || (c == '+') || (c == '/'));
}

static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
std::vector<uint8_t> base64_decode(const std::string& encoded_string) {
    int i = 0;
    int j = 0;
    int in_ = 0;
@@ -228,7 +131,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
// random string / id
//

static std::string random_string() {
std::string random_string() {
    static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");

    std::random_device rd;
@@ -243,33 +146,33 @@ static std::string random_string() {
    return result;
}

static std::string gen_chatcmplid() {
std::string gen_chatcmplid() {
    std::stringstream chatcmplid;
    chatcmplid << "chatcmpl-" << random_string();

    return chatcmplid.str();
}

static std::string gen_tool_call_id() {
std::string gen_tool_call_id() {
    return random_string();
}

//
// other common utils
//
static float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length) {
float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length) {
    float sim = float(lcp) * 2 / (prompt_length + cache_length);
    return sim;
}

static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t common_part(const std::vector<llama_token>& a, const std::vector<llama_token>& b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

    return i;
}

static size_t common_part(const std::string & a, const std::string & b) {
size_t common_part(const std::string& a, const std::string& b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

@@ -279,7 +182,7 @@ static size_t common_part(const std::string & a, const std::string & b) {
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
static size_t validate_utf8(const std::string& text) {
size_t validate_utf8(const std::string& text) {
    size_t len = text.size();
    if (len == 0) return 0;

@@ -319,8 +222,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
    return ret;
}

std::string tokens_to_str(llama_context* ctx, const llama_tokens& tokens) {
    return tokens_to_str(ctx, tokens.begin(), tokens.end());
}

// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token) {
    std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);

    // if the size is 1 and first bit is 1, meaning it's a partial character
@@ -335,10 +242,6 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
    return out;
}

struct common_prefix {
    size_t first = 0;
    size_t second = 0;
};

common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b) {
    common_prefix prefix;
@@ -494,19 +397,8 @@ common_prefix find_common_text_token_prefix(const llama_context * ctx, const lla
}


struct completion_token_output {
    llama_token tok;
    std::string text_to_send;
    float prob;

    struct prob_info {
        llama_token tok;
        std::string txt;
        float prob;
    };
    std::vector<prob_info> probs;

    json to_json(bool post_sampling_probs) const {
json completion_token_output::to_json(bool post_sampling_probs) const {
    json probs_for_token = json::array();
    for (const auto& p : probs) {
        std::string txt(p.txt);
@@ -524,12 +416,12 @@ struct completion_token_output {
        return probs_for_token;
    }

    static float logarithm(float x) {
float completion_token_output::logarithm(float x) {
    // nlohmann::json converts -inf to null, so we need to prevent that
    return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}

    static std::vector<unsigned char> str_to_bytes(const std::string& str) {
std::vector<unsigned char> completion_token_output::str_to_bytes(const std::string& str) {
    std::vector<unsigned char> bytes;
    for (unsigned char c : str) {
        bytes.push_back(c);
@@ -538,7 +430,7 @@ struct completion_token_output {
    }


    static json probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs) {
json completion_token_output::probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs) {
    json out = json::array();
    for (const auto& p : probs) {
        std::string txt(p.text_to_send);
@@ -559,10 +451,10 @@ struct completion_token_output {
    }
    return out;
}
};


// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
json probs_vector_to_json(const llama_context* ctx, const std::vector<completion_token_output>& probs) {
    json out = json::array();

    for (const auto& prob : probs) {
@@ -586,7 +478,7 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
    return out;
}

static bool server_sent_event(httplib::DataSink& sink, const json& data) {
bool server_sent_event(httplib::DataSink& sink, const json& data) {
    const std::string str =
        "data: " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -597,7 +489,7 @@ static bool server_sent_event(httplib::DataSink& sink, const json& data) {
    return sink.write(str.c_str(), str.size());
}

static bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) {
bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data) {
    const std::string str =
        (data.contains("event") && data.contains("data")) ?
        ("event: " + data.at("event").get<std::string>() + "\n" +
@@ -613,7 +505,7 @@ static bool server_sent_anthropic_event(httplib::DataSink& sink, const json& dat
// OAI utils
//
// used by /completions endpoint
static json oaicompat_chat_params_parse(const json& body) {
json oaicompat_chat_params_parse(const json& body) {
    json llama_params;

    if (!body.contains("prompt")) {
@@ -664,19 +556,9 @@ static json oaicompat_chat_params_parse(const json& body) {
    return llama_params;
}

struct oaicompat_parser_options {
    bool use_jinja;
    bool prefill_assistant;
    common_reasoning_format reasoning_format;
    std::map<std::string, std::string> chat_template_kwargs;
    common_chat_templates* tmpls;
    bool allow_image;
    bool allow_audio;
    bool enable_thinking = true;
};

// used by /chat/completions endpoint
static json oaicompat_chat_params_parse(
json oaicompat_chat_params_parse(
    const struct llama_model* model,
    json& body, /* openai api json semantics */
    const oaicompat_parser_options& opt,
@@ -906,7 +788,8 @@ static json oaicompat_chat_params_parse(
            for (auto& p : last_message.content_parts) {
                chat_params.prompt += p.text;
            }
        } else {
        }
        else {
            chat_params.prompt += last_message.content;
        }
    }
@@ -942,7 +825,8 @@ static json oaicompat_chat_params_parse(
            throw std::runtime_error("logprobs is not supported with tools + stream");
        }
        llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
    } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
    }
    else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
    }

@@ -960,7 +844,7 @@ static json oaicompat_chat_params_parse(
    return llama_params;
}

static json anthropic_params_from_json(
json anthropic_params_from_json(
    const struct llama_model* model,
    const json& body_in, /* anthropic messages api json semantics */
    const oaicompat_parser_options& opt,
@@ -971,14 +855,16 @@ static json anthropic_params_from_json(

    if (body.contains("stop_sequences")) {
        llama_params["stop"] = body.at("stop_sequences");
    } else {
    }
    else {
        llama_params["stop"] = json::array();
    }

    // handle max_tokens (required in Anthropic, but we're permissive)
    if (!body.contains("max_tokens")) {
        llama_params["n_predict"] = 4096;
    } else {
    }
    else {
        llama_params["n_predict"] = body.at("max_tokens");
    }

@@ -1010,7 +896,8 @@ static json anthropic_params_from_json(

    if (system_param.is_string()) {
        system_content = system_param.get<std::string>();
    } else if (system_param.is_array()) {
    }
    else if (system_param.is_array()) {
        for (const auto& block : system_param) {
            if (json_value(block, "type", std::string()) == "text") {
                system_content += json_value(block, "text", std::string());
@@ -1064,7 +951,8 @@ static json anthropic_params_from_json(

            if (type == "text") {
                converted_content.push_back(block);
            } else if (type == "image") {
            }
            else if (type == "image") {
                json source = json_value(block, "source", json::object());
                std::string source_type = json_value(source, "type", std::string());

@@ -1078,7 +966,8 @@ static json anthropic_params_from_json(
                        {"url", "data:" + media_type + ";base64," + data}
                    }}
                });
                } else if (source_type == "url") {
                }
                else if (source_type == "url") {
                    std::string url = json_value(source, "url", std::string());
                    converted_content.push_back({
                        {"type", "image_url"},
@@ -1087,7 +976,8 @@ static json anthropic_params_from_json(
                    }}
                });
            }
            } else if (type == "tool_use") {
            }
            else if (type == "tool_use") {
                tool_calls.push_back({
                    {"id", json_value(block, "id", std::string())},
                    {"type", "function"},
@@ -1097,14 +987,16 @@ static json anthropic_params_from_json(
                    }}
                });
                has_tool_calls = true;
            } else if (type == "tool_result") {
            }
            else if (type == "tool_result") {
                std::string tool_use_id = json_value(block, "tool_use_id", std::string());

                auto result_content = json_value(block, "content", json());
                std::string result_text;
                if (result_content.is_string()) {
                    result_text = result_content.get<std::string>();
                } else if (result_content.is_array()) {
                }
                else if (result_content.is_array()) {
                    for (const auto& c : result_content) {
                        if (json_value(c, "type", std::string()) == "text") {
                            result_text += json_value(c, "text", std::string());
@@ -1125,7 +1017,8 @@ static json anthropic_params_from_json(
            json new_msg = { {"role", role} };
            if (!converted_content.empty()) {
                new_msg["content"] = converted_content;
            } else if (has_tool_calls) {
            }
            else if (has_tool_calls) {
                new_msg["content"] = "";
            }
            if (!tool_calls.empty()) {
@@ -1136,12 +1029,14 @@ static json anthropic_params_from_json(
            for (const auto& tool_msg : tool_results) {
                oai_messages.push_back(tool_msg);
            }
        } else {
        }
        else {
            if (!converted_content.empty() || has_tool_calls) {
                json new_msg = { {"role", role} };
                if (!converted_content.empty()) {
                    new_msg["content"] = converted_content;
                } else if (has_tool_calls) {
                }
                else if (has_tool_calls) {
                    new_msg["content"] = "";
                }
                if (!tool_calls.empty()) {
@@ -1176,9 +1071,11 @@ static json anthropic_params_from_json(
        std::string type = json_value(tc, "type", std::string());
        if (type == "auto") {
            oai_tool_choice = "auto";
        } else if (type == "any") {
        }
        else if (type == "any") {
            oai_tool_choice = "required";
        } else if (type == "tool") {
        }
        else if (type == "tool") {
            oai_tool_choice = "required";
        }
    }
@@ -1218,19 +1115,24 @@ static json anthropic_params_from_json(
                raw_buffer data;
                data.insert(data.end(), res.second.begin(), res.second.end());
                out_files.push_back(data);
            } else {
            }
            else {
                throw std::runtime_error("Failed to download image");
            }
        } else {
        }
        else {
            // try to decode base64 image
            std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
            if (parts.size() != 2) {
                throw std::runtime_error("Invalid image_url.url value");
            } else if (!string_starts_with(parts[0], "data:image/")) {
            }
            else if (!string_starts_with(parts[0], "data:image/")) {
                throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
            } else if (!string_ends_with(parts[0], "base64")) {
            }
            else if (!string_ends_with(parts[0], "base64")) {
                throw std::runtime_error("image_url.url must be base64 encoded");
            } else {
            }
            else {
                auto base64_data = parts[1];
                auto decoded_data = base64_decode(base64_data);
                out_files.push_back(decoded_data);
@@ -1241,7 +1143,8 @@ static json anthropic_params_from_json(
            p["type"] = "text";
            p["text"] = mtmd_default_marker();
            p.erase("image_url");
        } else if (type == "input_audio") {
        }
        else if (type == "input_audio") {
            if (!opt.allow_audio) {
                throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
            }
@@ -1296,9 +1199,11 @@ static json anthropic_params_from_json(
    auto enable_thinking_kwarg = json_value(inputs.chat_template_kwargs, "enable_thinking", std::string(""));
    if (enable_thinking_kwarg == "true") {
        inputs.enable_thinking = true;
    } else if (enable_thinking_kwarg == "false") {
    }
    else if (enable_thinking_kwarg == "false") {
        inputs.enable_thinking = false;
    } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
    }
    else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
        throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
    }

@@ -1332,7 +1237,8 @@ static json anthropic_params_from_json(
            for (auto& p : last_message.content_parts) {
                chat_params.prompt += p.text;
            }
        } else {
        }
        else {
            chat_params.prompt += last_message.content;
        }
    }
@@ -1379,7 +1285,7 @@ static json anthropic_params_from_json(
// tokenizer and input processing utils
//

static bool json_is_array_of_numbers(const json& data) {
bool json_is_array_of_numbers(const json& data) {
    if (data.is_array()) {
        for (const auto& e : data) {
            if (!e.is_number_integer()) {
@@ -1392,7 +1298,7 @@ static bool json_is_array_of_numbers(const json& data) {
}

// is array having BOTH numbers & strings?
static bool json_is_array_of_mixed_numbers_strings(const json& data) {
bool json_is_array_of_mixed_numbers_strings(const json& data) {
    bool seen_string = false;
    bool seen_number = false;
    if (data.is_array()) {
@@ -1408,7 +1314,7 @@ static bool json_is_array_of_mixed_numbers_strings(const json& data) {
}

// does array have any individual integers/tokens?
static bool json_is_array_and_contains_numbers(const json& data) {
bool json_is_array_and_contains_numbers(const json& data) {
    if (data.is_array()) {
        for (const auto& e : data) {
            if (e.is_number_integer()) {
@@ -1421,7 +1327,7 @@ static bool json_is_array_and_contains_numbers(const json& data) {
}

// get value by path(key1 / key2)
static json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
json json_get_nested_values(const std::vector<std::string>& paths, const json& js) {
    json result = json::object();

    for (const std::string& path : paths) {
@@ -1449,7 +1355,7 @@ static json json_get_nested_values(const std::vector<std::string>& paths, const
 * - only string, example: "string"
 * - mixed string and tokens, example: [12, 34, "string", 56, 78]
 */
static std::vector<llama_token> tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special) {
std::vector<llama_token> tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special) {
    // If `add_bos` is true, we only add BOS, when json_prompt is a string,
    // or the first element of the json_prompt array is a string.
    std::vector<llama_token> prompt_tokens;
@@ -1488,19 +1394,19 @@ static std::vector<llama_token> tokenize_mixed(const llama_vocab* vocab, const j
    return prompt_tokens;
}

static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
json format_tokenizer_response(const std::vector<llama_token>& tokens) {
    return json{
        {"tokens", tokens}
    };
}

static json format_detokenized_response(const std::string & content) {
json format_detokenized_response(const std::string& content) {
    return json{
        {"content", content}
    };
}

static json format_error_response(const std::string & message, const enum error_type type) {
json format_error_response(const std::string& message, const enum error_type type) {
    std::string type_str;
    int code = 500;
    switch (type) {
@@ -1540,12 +1446,8 @@ static json format_error_response(const std::string & message, const enum error_
    };
}

struct token_probabilities {
    float sampled_token_p;
    std::vector<llama_token_data> cur;
};

static token_probabilities get_token_probabilities(llama_context * ctx, int idx, llama_token sampled_token_id, int n_sorted) {
token_probabilities get_token_probabilities(llama_context* ctx, int idx, llama_token sampled_token_id, int n_sorted) {
    const auto* logits = llama_get_logits_ith(ctx, idx);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
    n_sorted = std::min(n_sorted, n_vocab);
@@ -1584,54 +1486,17 @@ static token_probabilities get_token_probabilities(llama_context * ctx, int idx,
 * server_tokens is a helper to manage the input tokens and image for the server.
 * it is made this way to simplify the logic of KV cache management.
 */
struct server_tokens {
    bool has_mtmd = false;

private: // disallow accessing these members directly, risking out-of-sync

    // map a **start** index in tokens to the image chunk
    // note: the order need to be in-sync with tokens
    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;

    // list of tokens
    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
    // otherwise, it is a normal text token
    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
    // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
    llama_tokens tokens;

    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
    // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
    // idx  0   1   2   3   4     5      6      7      8      9      10
    // pos  0   1   2   3   4     5      5      5      7      7      7
    // map_idx_to_media will contain: {5, img0}, {8, img1}

public:
    server_tokens() = default;
    ~server_tokens() = default;

    // Prevent copying
    server_tokens(const server_tokens&) = delete;
    server_tokens& operator=(const server_tokens&) = delete;

    // Allow moving (usually implicitly generated if members are movable)
    server_tokens(server_tokens&&) = default;
    server_tokens& operator=(server_tokens&&) = default;

    // Allow accessing elements using [] operator
    llama_token operator[](size_t index) { return tokens[index]; }
    const llama_token& operator[](size_t index) const { return tokens[index]; }

    server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
server_tokens::server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
    for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
        push_back(mtmd_chunks[i]);
    }
}

    server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
}

    llama_pos pos_next() const {
llama_pos server_tokens::pos_next() const {
    if (!has_mtmd) {
        return tokens.size();
    }
@@ -1647,7 +1512,7 @@ public:
}

// for debugging
    std::string str() const {
std::string server_tokens::str() const {
    std::ostringstream oss;
    oss << "tokens: ";
    for (size_t idx = 0; idx < tokens.size(); ++idx) {
@@ -1668,7 +1533,7 @@ public:
    return oss.str();
}

    const mtmd::input_chunk_ptr& find_chunk(size_t idx) const {
const mtmd::input_chunk_ptr& server_tokens::find_chunk(size_t idx) const {
    auto it = map_idx_to_media.find(idx);
    if (it != map_idx_to_media.end()) {
        return it->second;
@@ -1676,7 +1541,7 @@ public:
    throw std::runtime_error("Chunk not found");
}

    void push_back(llama_token tok) {
void server_tokens::push_back(llama_token tok) {
    if (tok == LLAMA_TOKEN_NULL) {
        throw std::runtime_error("Invalid token");
    }
@@ -1684,7 +1549,7 @@ public:
}

// will create a copy of the chunk if it contains non-text data
    void push_back(const mtmd_input_chunk* chunk) {
void server_tokens::push_back(const mtmd_input_chunk* chunk) {
    auto type = mtmd_input_chunk_get_type(chunk);
    if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
        GGML_ASSERT(has_mtmd);
@@ -1709,7 +1574,7 @@ public:
}

// appends server tokens, updates the media map. copies media chunks.
    void push_back(server_tokens& tokens) {
void server_tokens::push_back(server_tokens& tokens) {
    size_t start_idx = size();
    for (size_t i = 0; i < tokens.size(); i++) {
        push_back(tokens[i]);
@@ -1727,66 +1592,66 @@ public:
}

// for compatibility with context shift and prompt truncation
    void insert(const std::vector<llama_token>& inp_tokens) {
void server_tokens::insert(const std::vector<llama_token>& inp_tokens) {
    GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
    tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
}

// for compatibility with context shift and prompt truncation
    void resize(size_t size) {
void server_tokens::resize(size_t size) {
    GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
    tokens.resize(size);
}

    llama_token * data() {
llama_token* server_tokens::data() {
    return tokens.data();
}

    llama_tokens::iterator begin() {
llama_tokens::iterator server_tokens::begin() {
    return tokens.begin();
}

    llama_tokens::iterator end() {
llama_tokens::iterator server_tokens::end() {
    return tokens.end();
}

    llama_tokens::const_iterator cbegin() {
llama_tokens::const_iterator server_tokens::cbegin() {
    return tokens.cbegin();
}

    llama_tokens::const_iterator cend() {
llama_tokens::const_iterator server_tokens::cend() {
    return tokens.cend();
}

    llama_tokens tokens_data() {
llama_tokens server_tokens::tokens_data() {
    return tokens;
}

// for compatibility with speculative decoding, ctx shift, slot save/load
    const std::vector<llama_token>& get_text_tokens() const {
const std::vector<llama_token>& server_tokens::get_text_tokens() const {
    GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
    return tokens;
}

// for compatibility with speculative decoding
    void set_token(llama_pos pos, llama_token id) {
void server_tokens::set_token(llama_pos pos, llama_token id) {
    GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
    tokens[pos] = id;
}

    size_t size() const {
size_t server_tokens::size() const {
    return tokens.size();
}

    bool empty() const {
bool server_tokens::empty() const {
    return tokens.empty();
}

    void clear() {
void server_tokens::clear() {
    tokens.clear();
}

    void keep_first(size_t n) {
void server_tokens::keep_first(size_t n) {
    GGML_ASSERT(n <= tokens.size());
    if (has_mtmd) {
        if (n == tokens.size()) {
@@ -1819,7 +1684,7 @@ public:
    tokens.resize(n);
}

    std::string detokenize(const llama_context* ctx, bool special) const {
std::string server_tokens::detokenize(const llama_context* ctx, bool special) const {
    llama_tokens text_tokens;
    text_tokens.reserve(tokens.size());
    for (const auto& t : tokens) {
@@ -1830,7 +1695,7 @@ public:
    return llama_detokenize(ctx, text_tokens, special);
}

    std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const {
std::string server_tokens::detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const {
    std::string str;
    if (tokens.size() <= start || length == 0) {
        return str;
@@ -1852,7 +1717,7 @@ public:
    return llama_detokenize(ctx, text_tokens, special);
}

    size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special,
size_t server_tokens::find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special,
    size_t start, const size_t length) {
    std::string str = detokenize(ctx, special, start, length);
    std::vector<size_t> tmp;
@@ -1860,7 +1725,7 @@ public:
    return n;
}

    size_t get_common_prefix_exact(const server_tokens& b) const {
size_t server_tokens::get_common_prefix_exact(const server_tokens& b) const {
    const size_t max_idx = std::min(tokens.size(), b.tokens.size());

    if (!has_mtmd) {
@@ -1909,7 +1774,7 @@ public:
}


    common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const {
common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const {
    common_prefix token_prefix;

    size_t n = get_common_prefix_exact(b); // strict token match as a starting point
@@ -1962,7 +1827,8 @@ public:
            prefix.first += n_tok_a;
            prefix.second += n_tok_a;
            token_prefix = common_prefix_add(prefix, token_prefix);
        } else {
        }
        else {
            // do no include image token prefix
            // only return text token prefix
            token_prefix = common_prefix_add(prefix, token_prefix);
@@ -1985,7 +1851,7 @@ public:

// take first n tokens of tokens list a
// find the common prefix between a and b
    common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const {
common_prefix server_tokens::get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact) const {
    // not work for mtmd
    GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
    auto tokens = get_text_tokens();
@@ -1998,7 +1864,7 @@ public:
}

// make sure all text tokens are within the vocab range
    bool validate(const struct llama_context* ctx) const {
bool server_tokens::validate(const struct llama_context* ctx) const {
    const llama_model* model = llama_get_model(ctx);
    const llama_vocab* vocab = llama_model_get_vocab(model);
    const int32_t n_vocab = llama_vocab_n_tokens(vocab);
@@ -2023,7 +1889,7 @@ public:
}

// encode and decode the image chunk
    int32_t process_chunk(
int32_t server_tokens::process_chunk(
    llama_context* ctx,
    mtmd_context* mctx,
    size_t idx,
@@ -2055,7 +1921,7 @@ public:
}

// Keep the first n_keep and remove n_discard tokens from tokens
    void discard_n_tokens(int32_t n_keep, int32_t n_discard) {
void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) {
    if (n_discard <= 0 || n_keep + n_discard >= size()) {
        return;
    }
@@ -2072,7 +1938,7 @@ public:
}

// Similarity between prompt and cached
    float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
float server_tokens::get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const {
    GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
    float sim_cur = 0;
    if (n_keep == 0 && n_discard == 0) {
@@ -2090,7 +1956,7 @@ public:
}

// Similarity between common part and cache
    float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const {
float server_tokens::get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const {
    GGML_ASSERT(n_keep >= 0 && n_discard >= 0);
    float sim_cur = 0;
    if (n_keep == 0 && n_discard == 0) {
@@ -2106,10 +1972,10 @@ public:
    }
    return sim_cur;
}
};


// Computes FNV-1a hash of the data
static std::string fnv_hash(const uint8_t * data, size_t len) {
std::string fnv_hash(const uint8_t* data, size_t len) {
    const uint64_t fnv_prime = 0x100000001b3ULL;
    uint64_t hash = 0xcbf29ce484222325ULL;

@@ -2120,7 +1986,7 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
    return std::to_string(hash);
}

static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector<raw_buffer> files) {
    mtmd::bitmaps bitmaps;
    for (auto& file : files) {
        mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
@@ -2163,7 +2029,7 @@ static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt
 * - "prompt": [12, 34, "string", 56, 78]
 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
 */
static server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) {
server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) {
    constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string";
    constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data";
    const bool has_mtmd = mctx != nullptr;
@@ -2214,7 +2080,7 @@ static server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_con
 * - "prompt": [[12, 34, 56], [78, 90, 12]]
 * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
 */
static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) {
std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special) {
    std::vector<server_tokens> result;
    if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) {
        result.reserve(json_prompt.size());
@@ -2231,7 +2097,7 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* voca
    return result;
}
// Assuming raw_buffer has .data() and .size() members
inline void print_files_info(const std::vector<raw_buffer>& files) {
void print_files_info(const std::vector<raw_buffer>& files) {
    for (size_t i = 0; i < files.size(); ++i) {
        const auto& file = files[i];
        std::cout << "File " << i << ": Size = " << file.size() << " bytes\n";
@@ -2246,7 +2112,7 @@ inline void print_files_info(const std::vector<raw_buffer>& files) {
    }
}

inline bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
    const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix) {
    std::string common_cache = cache_tokens.detokenize(ctx, true, start, prefix.first);
    std::string common_prompt = prompt_tokens.detokenize(ctx, true, start, prefix.second);
455
examples/server/server-common.h
Normal file
@@ -0,0 +1,455 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include <src/llama-impl.h>
|
||||
#include "chat.h"
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
|
||||
|
||||
|
||||
#include "base64.hpp"

#include <sstream>
#include <random>
#include <set>
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
// increase backlog size to avoid connection resets for >> 1 slots
|
||||
#define CPPHTTPLIB_LISTEN_BACKLOG 512
|
||||
// increase max URI length to handle longer prompts in query string
|
||||
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
|
||||
// disable Nagle's algorithm
|
||||
#define CPPHTTPLIB_TCP_NODELAY true
|
||||
#include "httplib.h"
|
||||
|
||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
ERROR_TYPE_INVALID_REQUEST,
|
||||
ERROR_TYPE_AUTHENTICATION,
|
||||
ERROR_TYPE_SERVER,
|
||||
ERROR_TYPE_NOT_FOUND,
|
||||
ERROR_TYPE_PERMISSION,
|
||||
ERROR_TYPE_UNAVAILABLE, // custom error
|
||||
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
||||
};
|
||||
|
||||
extern bool server_verbose;
|
||||
extern bool server_log_json;
|
||||
|
||||
#ifndef SERVER_VERBOSE
|
||||
#define SERVER_VERBOSE 1
|
||||
#endif
|
||||
|
||||
#if SERVER_VERBOSE != 1
|
||||
#define LOG_VERBOSE(MSG, ...)
|
||||
#else
|
||||
#define LOG_VERBOSE(MSG, ...) \
|
||||
do \
|
||||
{ \
|
||||
if (server_verbose) \
|
||||
{ \
|
||||
server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
void server_log(const char* level, const char* function, int line, const char* message, const json& extra);
|
||||
|
||||
template <typename T>
|
||||
static T json_value(const json& body, const std::string& key, const T& default_value) {
|
||||
// Fall back to the default value when the key is missing or null
|
||||
if (body.contains(key) && !body.at(key).is_null()) {
|
||||
try {
|
||||
return body.at(key);
|
||||
}
|
||||
catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const& err) {
|
||||
std::stringstream ss;
|
||||
ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value: " << err.what();
|
||||
LOG_WARNING(ss.str().c_str(), body);
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
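For orientation, a hedged usage sketch of json_value(): the field names below are illustrative and not taken from the diff; only DEFAULT_OAICOMPAT_MODEL is defined above.

json body = json::parse(R"({"n_predict": 64, "stream": true})");
int  n_predict = json_value(body, "n_predict", 128);                                  // -> 64
bool stream    = json_value(body, "stream", false);                                   // -> true
std::string model = json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL));  // missing key -> default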
|
||||
// thin wrapper around common_grammar_trigger with (de)serialization functions
|
||||
struct server_grammar_trigger {
|
||||
common_grammar_trigger value;
|
||||
|
||||
server_grammar_trigger() = default;
|
||||
server_grammar_trigger(const common_grammar_trigger& value) : value(value) {}
|
||||
server_grammar_trigger(const json& in);
|
||||
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// chat template utils
|
||||
//
|
||||
|
||||
//
|
||||
// base64 utils (TODO: move to common in the future)
|
||||
//
|
||||
|
||||
static const std::string base64_chars =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"0123456789+/";
|
||||
|
||||
bool is_base64(uint8_t c);
|
||||
|
||||
std::vector<uint8_t> base64_decode(const std::string& encoded_string);
|
||||
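A plausible sketch of is_base64(), assuming it simply tests membership in the alphabet declared above; the real definition lives in the .cpp and may differ:

#include <cctype>

static bool is_base64_sketch(uint8_t c) {
    // alphanumerics plus '+' and '/'; '=' padding is left to the decoder
    return (std::isalnum(c) != 0) || (c == '+') || (c == '/');
}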
|
||||
//
|
||||
// random string / id
|
||||
//
|
||||
|
||||
std::string random_string();
|
||||
|
||||
std::string gen_chatcmplid();
|
||||
|
||||
std::string gen_tool_call_id();
|
||||
|
||||
//
|
||||
// other common utils
|
||||
//
|
||||
float get_slot_similarity(size_t lcp, size_t prompt_length, size_t cache_length);
|
||||
|
||||
size_t common_part(const std::vector<llama_token>& a, const std::vector<llama_token>& b);
|
||||
|
||||
size_t common_part(const std::string& a, const std::string& b);
|
||||
|
||||
// return the last index of character that can form a valid string
|
||||
// if the last character is potentially cut in half, return the index before the cut
|
||||
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
|
||||
size_t validate_utf8(const std::string& text);
|
||||
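A short usage sketch of validate_utf8() based on the comment above: when streaming, only the UTF-8-complete prefix is emitted and the tail stays buffered (variable names are illustrative):

const size_t n_valid = validate_utf8(generated_text);
const std::string to_send = generated_text.substr(0, n_valid); // safe to emit now
// bytes [n_valid, generated_text.size()) remain buffered until the multi-byte character completes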
|
||||
// TODO: reuse llama_detokenize
|
||||
|
||||
std::string tokens_to_str(llama_context* ctx, const llama_tokens& tokens);
|
||||
|
||||
// format incomplete utf-8 multibyte character for output
|
||||
std::string tokens_to_output_formatted_string(const llama_context* ctx, const llama_token token);
|
||||
|
||||
struct common_prefix {
|
||||
size_t first = 0;
|
||||
size_t second = 0;
|
||||
};
|
||||
|
||||
common_prefix common_prefix_add(const common_prefix& a, const common_prefix& b);
|
||||
|
||||
common_prefix find_common_string_prefix(const std::string& a_str, const std::string& b_str, const std::set<char>& ignore_set);
|
||||
|
||||
size_t find_n_tokens_from_string(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start,
|
||||
std::vector<size_t>& map);
|
||||
|
||||
std::string remove_with_set(std::string str, const std::set<char>& chars_to_remove);
|
||||
|
||||
common_prefix find_largest_common_number(const std::vector<size_t>& a_list, const std::vector<size_t>& b_list);
|
||||
|
||||
size_t find_n_tokens_from_string_with_ignore(const llama_context* ctx, const llama_tokens& a, const size_t max_size, size_t start, const std::set<char>& ignore_set,
|
||||
std::vector<size_t>& map);
|
||||
|
||||
common_prefix find_common_text_token_prefix(const llama_context* ctx, const llama_tokens& a, const llama_tokens& b,
|
||||
size_t start, bool exact);
|
||||
|
||||
struct completion_token_output {
|
||||
llama_token tok;
|
||||
std::string text_to_send;
|
||||
float prob;
|
||||
|
||||
struct prob_info {
|
||||
llama_token tok;
|
||||
std::string txt;
|
||||
float prob;
|
||||
};
|
||||
std::vector<prob_info> probs;
|
||||
|
||||
json to_json(bool post_sampling_probs) const;
|
||||
|
||||
static float logarithm(float x);
|
||||
|
||||
static std::vector<unsigned char> str_to_bytes(const std::string& str);
|
||||
|
||||
static json probs_vector_to_json(const std::vector<completion_token_output>& probs, bool post_sampling_probs);
|
||||
};
|
||||
|
||||
// convert a vector of completion_token_output to json
|
||||
json probs_vector_to_json(const llama_context* ctx, const std::vector<completion_token_output>& probs);
|
||||
|
||||
bool server_sent_event(httplib::DataSink& sink, const json& data);
|
||||
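A minimal sketch of what server_sent_event() presumably does with cpp-httplib's DataSink (write one "data: ..." SSE frame); the actual implementation in the .cpp may add event names or flushing:

static bool server_sent_event_sketch(httplib::DataSink & sink, const json & data) {
    const std::string str = "data: " + data.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
    return sink.write(str.c_str(), str.size());
}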
|
||||
bool server_sent_anthropic_event(httplib::DataSink& sink, const json& data);
|
||||
|
||||
//
|
||||
// OAI utils
|
||||
//
|
||||
// used by /completions endpoint
|
||||
json oaicompat_chat_params_parse(const json& body);
|
||||
|
||||
struct oaicompat_parser_options {
|
||||
bool use_jinja;
|
||||
bool prefill_assistant;
|
||||
common_reasoning_format reasoning_format;
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
common_chat_templates* tmpls;
|
||||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
};
|
||||
|
||||
// used by /chat/completions endpoint
|
||||
json oaicompat_chat_params_parse(
|
||||
const struct llama_model* model,
|
||||
json& body, /* openai api json semantics */
|
||||
const oaicompat_parser_options& opt,
|
||||
std::vector<raw_buffer>& out_files);
|
||||
|
||||
json anthropic_params_from_json(
|
||||
const struct llama_model* model,
|
||||
const json& body_in, /* anthropic messages api json semantics */
|
||||
const oaicompat_parser_options& opt,
|
||||
std::vector<raw_buffer>& out_files);
|
||||
|
||||
|
||||
//
|
||||
// tokenizer and input processing utils
|
||||
//
|
||||
|
||||
bool json_is_array_of_numbers(const json& data);
|
||||
|
||||
// does the array contain BOTH numbers and strings?
|
||||
bool json_is_array_of_mixed_numbers_strings(const json& data);
|
||||
|
||||
// does the array contain any individual integers/tokens?
|
||||
bool json_is_array_and_contains_numbers(const json& data);
|
||||
|
||||
// get value by path(key1 / key2)
|
||||
json json_get_nested_values(const std::vector<std::string>& paths, const json& js);
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
|
||||
*/
|
||||
std::vector<llama_token> tokenize_mixed(const llama_vocab* vocab, const json& json_prompt, bool add_special, bool parse_special);
|
||||
|
||||
json format_tokenizer_response(const std::vector<llama_token>& tokens);
|
||||
|
||||
json format_detokenized_response(const std::string& content);
|
||||
|
||||
json format_error_response(const std::string& message, const enum error_type type);
|
||||
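For context, a hedged sketch of the error_type-to-HTTP-status mapping such a helper usually pairs with; the concrete codes are an assumption, not read from this diff:

static int error_type_to_status_sketch(enum error_type type) {
    switch (type) {
        case ERROR_TYPE_INVALID_REQUEST: return 400;
        case ERROR_TYPE_AUTHENTICATION:  return 401;
        case ERROR_TYPE_PERMISSION:      return 403;
        case ERROR_TYPE_NOT_FOUND:       return 404;
        case ERROR_TYPE_SERVER:          return 500;
        case ERROR_TYPE_NOT_SUPPORTED:   return 501;
        case ERROR_TYPE_UNAVAILABLE:     return 503;
    }
    return 500;
}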
|
||||
struct token_probabilities {
|
||||
float sampled_token_p;
|
||||
std::vector<llama_token_data> cur;
|
||||
};
|
||||
|
||||
token_probabilities get_token_probabilities(llama_context* ctx, int idx, llama_token sampled_token_id, int n_sorted);
|
||||
|
||||
/**
|
||||
* server_tokens is a helper to manage the input tokens and image for the server.
|
||||
* it is made this way to simplify the logic of KV cache management.
|
||||
*/
|
||||
struct server_tokens {
|
||||
bool has_mtmd = false;
|
||||
|
||||
private: // disallow direct access to these members, which risks them getting out of sync
|
||||
|
||||
// map a **start** index in tokens to the image chunk
|
||||
// note: the order needs to be in sync with tokens
|
||||
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
|
||||
|
||||
// list of tokens
|
||||
// if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
|
||||
// otherwise, it is a normal text token
|
||||
// note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
|
||||
// note(2): for M-RoPE, an image can occupy a different number of positions; do not assume a 1-to-1 mapping between tokens and pos
|
||||
llama_tokens tokens;
|
||||
|
||||
// for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
|
||||
// idx 0 1 2 3 4 5 6 7 8 9 10
|
||||
// pos 0 1 2 3 4 5 5 5 7 7 7
|
||||
// map_idx_to_media will contain: {5, img0}, {8, img1}
|
||||
|
||||
public:
|
||||
server_tokens() = default;
|
||||
~server_tokens() = default;
|
||||
|
||||
// Prevent copying
|
||||
server_tokens(const server_tokens&) = delete;
|
||||
server_tokens& operator=(const server_tokens&) = delete;
|
||||
|
||||
// Allow moving (usually implicitly generated if members are movable)
|
||||
server_tokens(server_tokens&&) = default;
|
||||
server_tokens& operator=(server_tokens&&) = default;
|
||||
|
||||
// Allow accessing elements using [] operator
|
||||
llama_token operator[](size_t index) { return tokens[index]; }
|
||||
const llama_token& operator[](size_t index) const { return tokens[index]; }
|
||||
|
||||
server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd);
|
||||
|
||||
server_tokens(const llama_tokens& tokens, bool has_mtmd);
|
||||
|
||||
llama_pos pos_next() const;
|
||||
|
||||
// for debugging
|
||||
std::string str() const;
|
||||
|
||||
const mtmd::input_chunk_ptr& find_chunk(size_t idx) const;
|
||||
|
||||
void push_back(llama_token tok);
|
||||
|
||||
// will create a copy of the chunk if it contains non-text data
|
||||
void push_back(const mtmd_input_chunk* chunk);
|
||||
|
||||
// appends server tokens, updates the media map. copies media chunks.
|
||||
void push_back(server_tokens& tokens);
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void insert(const std::vector<llama_token>& inp_tokens);
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void resize(size_t size);
|
||||
|
||||
llama_token* data();
|
||||
|
||||
llama_tokens::iterator begin();
|
||||
|
||||
llama_tokens::iterator end();
|
||||
|
||||
llama_tokens::const_iterator cbegin();
|
||||
|
||||
llama_tokens::const_iterator cend();
|
||||
|
||||
llama_tokens tokens_data();
|
||||
|
||||
// for compatibility with speculative decoding, ctx shift, slot save/load
|
||||
const std::vector<llama_token>& get_text_tokens() const;
|
||||
|
||||
// for compatibility with speculative decoding
|
||||
void set_token(llama_pos pos, llama_token id);
|
||||
|
||||
size_t size() const;
|
||||
|
||||
bool empty() const;
|
||||
|
||||
void clear();
|
||||
|
||||
void keep_first(size_t n);
|
||||
|
||||
std::string detokenize(const llama_context* ctx, bool special) const;
|
||||
|
||||
std::string detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const;
|
||||
|
||||
size_t find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special,
|
||||
size_t start, const size_t length);
|
||||
|
||||
size_t get_common_prefix_exact(const server_tokens& b) const;
|
||||
|
||||
|
||||
common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const;
|
||||
// take the first n tokens of token list a
// and find the common prefix between that prefix and b
|
||||
common_prefix get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact = false) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context* ctx) const;
|
||||
|
||||
// encode and decode the image chunk
|
||||
int32_t process_chunk(
|
||||
llama_context* ctx,
|
||||
mtmd_context* mctx,
|
||||
size_t idx,
|
||||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
size_t& n_tokens_out) const;
|
||||
|
||||
// Keep the first n_keep and remove n_discard tokens from tokens
|
||||
void discard_n_tokens(int32_t n_keep, int32_t n_discard);
|
||||
|
||||
// Similarity between prompt and cached
|
||||
float get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const;
|
||||
|
||||
// Similarity between common part and cache
|
||||
float get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep = 0, int n_discard = 0) const;
|
||||
};
|
||||
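A hedged sketch of walking a server_tokens object given the layout documented above (LLAMA_TOKEN_NULL marks positions covered by a media chunk, chunk starts are tracked in map_idx_to_media). find_chunk() is the accessor declared in the struct; mtmd_input_chunk_get_n_tokens() is assumed to be available from mtmd.h; the rest is illustrative:

void walk_tokens_sketch(const server_tokens & toks) {
    size_t i = 0;
    while (i < toks.size()) {
        if (toks[i] == LLAMA_TOKEN_NULL) {
            // assumed: the first NULL position of a run is the chunk start recorded in the map
            const mtmd::input_chunk_ptr & chunk = toks.find_chunk(i);
            i += mtmd_input_chunk_get_n_tokens(chunk.get()); // skip over the whole media chunk
        } else {
            // toks[i] is a regular text token
            ++i;
        }
    }
}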
|
||||
// Computes FNV-1a hash of the data
|
||||
std::string fnv_hash(const uint8_t* data, size_t len);
|
||||
|
||||
server_tokens process_mtmd_prompt(mtmd_context* mctx, std::string prompt, std::vector<raw_buffer> files);
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompts if needed, then tokenize them
|
||||
* use tokenize_input_prompts() if the input could be an array.
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
*/
|
||||
server_tokens tokenize_input_subprompt(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special);
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompts if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, 56], [78, 90, 12]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
|
||||
*/
|
||||
std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* vocab, mtmd_context* mctx, const json& json_prompt, bool add_special, bool parse_special);
|
||||
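An illustrative call mirroring the shapes listed in the comment above; vocab and mctx are assumed to be in scope:

json prompts = json::parse(R"(["string1", [12, 34, 56], {"prompt_string": "string", "multimodal_data": []}])");
std::vector<server_tokens> inputs = tokenize_input_prompts(vocab, mctx, prompts, /*add_special=*/true, /*parse_special=*/true);
// inputs.size() == 3 -> one server_tokens per sub-prompt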
|
||||
// Assuming raw_buffer has .data() and .size() members
|
||||
void print_files_info(const std::vector<raw_buffer>& files);
|
||||
|
||||
bool prompt_cache_equal(llama_context* ctx, const server_tokens& cache_tokens,
|
||||
const server_tokens& prompt_tokens, size_t start, const common_prefix& prefix);
|
||||
2763
examples/server/server-context.cpp
Normal file
File diff suppressed because it is too large
316
examples/server/server-context.h
Normal file
@@ -0,0 +1,316 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "speculative.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
|
||||
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_PROCESSING,
|
||||
};
|
||||
|
||||
|
||||
|
||||
enum slot_command {
|
||||
SLOT_COMMAND_NONE,
|
||||
SLOT_COMMAND_LOAD_PROMPT,
|
||||
SLOT_COMMAND_RELEASE,
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
struct slot_params {
|
||||
bool stream = true;
|
||||
bool include_usage = false;
|
||||
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
||||
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
|
||||
// speculative decoding parameters
|
||||
struct {
|
||||
int n_max = 16; // max drafted tokens
|
||||
int n_min = 0; // min drafted tokens to accept
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
} speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
int id_task = -1;
|
||||
int id_multi = -1;
|
||||
|
||||
struct slot_params params;
|
||||
|
||||
slot_state state = SLOT_STATE_IDLE;
|
||||
slot_command command = SLOT_COMMAND_NONE;
|
||||
|
||||
llama_context* ctx = nullptr;
|
||||
// used to determine the slot that has been used the longest
|
||||
int64_t t_last_used = -1;
|
||||
|
||||
std::unique_ptr<const server_task> task;
|
||||
|
||||
// generation props
|
||||
int32_t n_ctx = 0; // context size per slot
|
||||
int32_t n_past = 0;
|
||||
int32_t n_past_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t n_remaining = -1;
|
||||
int32_t n_discarded_prompt = 0;
|
||||
int32_t n_kept_prompt = 0;
|
||||
|
||||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt; // can be either a string, array of strings or array of token ids
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
server_tokens prompt_tokens;
|
||||
server_tokens cache_tokens;
|
||||
|
||||
std::string generated_text;
|
||||
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
common_chat_msg chat_msg;
|
||||
|
||||
bool infill = false;
|
||||
bool embedding = false;
|
||||
bool has_next_token = true;
|
||||
bool truncated = false;
|
||||
bool stopped_eos = false;
|
||||
bool stopped_word = false;
|
||||
bool stopped_limit = false;
|
||||
|
||||
bool oaicompat = false;
|
||||
|
||||
std::string oaicompat_model;
|
||||
std::string stopping_word;
|
||||
stop_type stop;
|
||||
|
||||
server_prompt server_cached_prompt;
|
||||
|
||||
void prompt_save(server_prompt_cache& prompt_cache) const;
|
||||
|
||||
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
|
||||
|
||||
// sampling
|
||||
llama_token sampled;
|
||||
struct llama_sampling_params sparams;
|
||||
llama_sampling_context* ctx_sampling = nullptr;
|
||||
json json_schema;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
int32_t ga_i = 0; // group-attention state
|
||||
int32_t ga_n = 1; // group-attention factor
|
||||
int32_t ga_w = 512; // group-attention width
|
||||
|
||||
// multimodal
|
||||
mtmd_context* mctx = nullptr;
|
||||
|
||||
// speculative decoding
|
||||
struct llama_speculative* spec = nullptr;
|
||||
llama_context* ctx_dft = nullptr;
|
||||
llama_batch batch_spec = {};
|
||||
|
||||
// speculative decoding stats
|
||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||
|
||||
int32_t n_past_se = 0; // self-extend
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text characters
|
||||
size_t n_sent_token_probs = 0;
|
||||
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_generation;
|
||||
|
||||
double t_prompt_processing; // ms
|
||||
double t_token_generation; // ms
|
||||
|
||||
void reset();
|
||||
|
||||
bool has_budget(gpt_params& global_params);
|
||||
|
||||
bool available() const;
|
||||
|
||||
bool is_processing() const;
|
||||
|
||||
void add_token_string(const completion_token_output& token);
|
||||
|
||||
void release();
|
||||
|
||||
json get_formated_timings() const;
|
||||
|
||||
result_timings get_timings() const;
|
||||
|
||||
const common_chat_msg& update_chat_msg(std::vector<common_chat_msg_diff>& diffs);
|
||||
|
||||
size_t find_stopping_strings(const std::string& text, const size_t last_token_size, bool is_full_stop);
|
||||
|
||||
void print_timings() const;
|
||||
};
|
||||
|
||||
struct server_metrics {
|
||||
int64_t t_start = 0;
|
||||
|
||||
uint64_t n_prompt_tokens_processed_total = 0;
|
||||
uint64_t t_prompt_processing_total = 0;
|
||||
uint64_t n_tokens_predicted_total = 0;
|
||||
uint64_t t_tokens_generation_total = 0;
|
||||
|
||||
uint64_t n_prompt_tokens_processed = 0;
|
||||
uint64_t t_prompt_processing = 0;
|
||||
|
||||
uint64_t n_tokens_predicted = 0;
|
||||
uint64_t t_tokens_generation = 0;
|
||||
|
||||
void init();
|
||||
|
||||
void on_prompt_eval(const server_slot& slot);
|
||||
|
||||
void on_prediction(const server_slot& slot);
|
||||
|
||||
void reset_bucket();
|
||||
};
|
||||
|
||||
struct server_context {
|
||||
llama_model* model = nullptr;
|
||||
llama_context* ctx = nullptr;
|
||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||
|
||||
gpt_params params;
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
bool has_eos_token = false;
|
||||
|
||||
// multimodal
|
||||
mtmd_context* mctx = nullptr;
|
||||
|
||||
// For speculative decoding
|
||||
llama_model* model_draft = nullptr;
|
||||
llama_context* ctx_draft = nullptr;
|
||||
llama_context_params cparams_dft;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
|
||||
// system prompt
|
||||
bool system_need_update = false;
|
||||
|
||||
std::string system_prompt;
|
||||
std::vector<llama_token> system_tokens;
|
||||
|
||||
// slots / clients
|
||||
std::vector<server_slot> slots;
|
||||
json default_generation_settings_for_props;
|
||||
|
||||
server_queue queue_tasks;
|
||||
server_response queue_results;
|
||||
|
||||
std::unique_ptr<server_prompt_cache> prompt_cache;
|
||||
|
||||
server_metrics metrics;
|
||||
|
||||
common_chat_templates_ptr chat_templates;
|
||||
oaicompat_parser_options oai_parser_opt;
|
||||
// Necessary similarity of prompt for slot selection
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
int32_t cache_ram_n_min = 0;
|
||||
float cache_ram_similarity = 0.5f;
|
||||
|
||||
~server_context();
|
||||
|
||||
bool load_model(const gpt_params& params_);
|
||||
|
||||
void init();
|
||||
|
||||
std::vector<llama_token> tokenize(const json& json_prompt, bool add_special) const;
|
||||
|
||||
server_slot* get_slot_by_id(int id);
|
||||
|
||||
server_slot* get_available_slot(const server_task& task);
|
||||
|
||||
bool launch_slot_with_task(server_slot& slot, server_task& task);
|
||||
|
||||
void kv_cache_clear();
|
||||
|
||||
void system_prompt_update();
|
||||
|
||||
bool system_prompt_set(const std::string& sys_prompt);
|
||||
|
||||
bool process_token(completion_token_output& result, server_slot& slot);
|
||||
|
||||
void populate_token_probs(const server_slot& slot, completion_token_output& result, bool post_sampling, bool special, int idx);
|
||||
|
||||
json get_formated_generation(const server_slot& slot) const;
|
||||
|
||||
void send_error(const server_task& task, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
void send_error(const server_slot& slot, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER);
|
||||
|
||||
// if multimodal is enabled, send an error and return false
|
||||
bool ensure_no_mtmd(const int id_task);
|
||||
|
||||
void send_partial_response(server_slot& slot, completion_token_output tkn);
|
||||
|
||||
void send_final_response(server_slot& slot);
|
||||
|
||||
void send_embedding(const server_slot& slot, const llama_batch& batch);
|
||||
|
||||
void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, server_tokens&& inputs);
|
||||
|
||||
void request_cancel(int id_task);
|
||||
|
||||
void split_multiprompt_task(int id_multi, server_task& multiprompt_task);
|
||||
|
||||
void process_single_task(server_task&& task);
|
||||
|
||||
void on_finish_multitask(const server_task_multi& multitask);
|
||||
|
||||
void print_tokens(const server_tokens& prompt, const server_tokens& cache, size_t start1 = 0, size_t start2 = 0, size_t length = 10);
|
||||
|
||||
void discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard);
|
||||
|
||||
// map the "keep first n_keep, discard next n_discard" boundaries computed on a onto the corresponding tokens in b
|
||||
void context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep,
|
||||
int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact = false);
|
||||
|
||||
void context_shift_prompt(llama_context* ctx, server_slot& slot, bool exact = false);
|
||||
|
||||
void update_slots();
|
||||
|
||||
json model_meta() const;
|
||||
};
|
||||
194
examples/server/server-queue.cpp
Normal file
@@ -0,0 +1,194 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
#include "server-common.h"
|
||||
|
||||
#include "log.h"
|
||||
#include <chrono>
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
|
||||
int server_queue::post(server_task task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (task.id == -1) {
task.id = id++;
//LOG_VERBOSE("new task id", { {"new_id", task.id} });
QUE_DBG("new task, id = %d\n", task.id);
}
// remember the id before the task is moved into the queue
const int task_id = task.id;
queue_tasks.push_back(std::move(task));
condition_tasks.notify_one();
return task_id;
}
|
||||
|
||||
void server_queue::defer(server_task&& task) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
queue_tasks_deferred.push_back(std::move(task));
|
||||
}
|
||||
|
||||
int server_queue::get_new_id() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
int new_id = id++;
|
||||
//LOG_VERBOSE("new task id", { {"new_id", new_id} });
|
||||
QUE_DBG("new task, id = %d\n", id);
|
||||
return new_id;
|
||||
}
|
||||
|
||||
void server_queue::notify_slot_changed() {
|
||||
// move deferred tasks back to main loop
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
for (auto& task : queue_tasks_deferred) {
|
||||
queue_tasks.push_back(std::move(task));
|
||||
}
|
||||
queue_tasks_deferred.clear();
|
||||
}
|
||||
|
||||
void server_queue::on_new_task(std::function<void(server_task&&)> callback) {
|
||||
callback_new_task = std::move(callback);
|
||||
}
|
||||
|
||||
|
||||
void server_queue::start_loop() {
|
||||
running = true;
|
||||
|
||||
while (true) {
|
||||
LOG_VERBOSE("new task may arrive", {});
|
||||
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
lock.unlock();
|
||||
break;
|
||||
}
|
||||
server_task task = std::move(queue_tasks.front());
|
||||
queue_tasks.erase(queue_tasks.begin());
|
||||
lock.unlock();
|
||||
//LOG_VERBOSE("callback_new_task", { {"id_task", task.id} });
|
||||
callback_new_task(std::move(task));
|
||||
}
|
||||
|
||||
LOG_VERBOSE("update_multitasks", {});
|
||||
|
||||
// check if we have any finished multitasks
|
||||
auto queue_iterator = queue_multitasks.begin();
|
||||
while (queue_iterator != queue_multitasks.end()) {
|
||||
if (queue_iterator->subtasks_remaining.empty()) {
|
||||
// all subtasks done == multitask is done
|
||||
server_task_multi current_multitask = *queue_iterator;
|
||||
callback_finish_multitask(current_multitask);
|
||||
// remove this multitask
|
||||
queue_iterator = queue_multitasks.erase(queue_iterator);
|
||||
}
|
||||
else {
|
||||
++queue_iterator;
|
||||
}
|
||||
}
|
||||
|
||||
// all tasks in the current loop are processed; slot data is now ready
|
||||
LOG_VERBOSE("callback_update_slots", {});
|
||||
|
||||
callback_update_slots();
|
||||
|
||||
LOG_VERBOSE("wait for new task", {});
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
if (!running) {
|
||||
LOG_VERBOSE("ending start_loop", {});
|
||||
return;
|
||||
}
|
||||
condition_tasks.wait(lock, [&] {
|
||||
return (!queue_tasks.empty() || !running);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void server_queue::add_multitask(int id_multi, std::vector<int>& sub_ids) {
|
||||
std::lock_guard<std::mutex> lock(mutex_tasks);
|
||||
server_task_multi multi;
|
||||
multi.id = id_multi;
|
||||
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
|
||||
queue_multitasks.push_back(multi);
|
||||
}
|
||||
|
||||
|
||||
void server_queue::update_multitask(int id_multi, int id_sub, server_task_result& result) {
|
||||
std::lock_guard<std::mutex> lock(mutex_tasks);
|
||||
for (auto& multitask : queue_multitasks) {
|
||||
if (multitask.id == id_multi) {
|
||||
multitask.subtasks_remaining.erase(id_sub);
|
||||
multitask.results.push_back(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void server_response::add_waiting_task_id(int id_task) {
|
||||
//LOG_VERBOSE("waiting for task id", { {"id_task", id_task} });
|
||||
QUE_DBG("waiting for task id, id = %d\n", id_task);
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.insert(id_task);
|
||||
}
|
||||
|
||||
void server_response::remove_waiting_task_id(int id_task) {
|
||||
//LOG_VERBOSE("remove waiting for task id", { {"id_task", id_task} });
|
||||
QUE_DBG("remove waiting for task id, id = %d\n", id_task);
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.erase(id_task);
|
||||
}
|
||||
|
||||
|
||||
server_task_result server_response::recv(int id_task) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&] {
|
||||
return !queue_results.empty();
|
||||
});
|
||||
|
||||
for (int i = 0; i < (int)queue_results.size(); i++) {
|
||||
if (queue_results[i].id == id_task) {
|
||||
assert(queue_results[i].id_multi == -1);
|
||||
server_task_result res = queue_results[i];
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
}
|
||||
|
||||
void server_response::send(server_task_result result) {
|
||||
//LOG_VERBOSE("send new result", { {"id_task", result.id} });
|
||||
QUE_DBG("send new result, id = %d\n", result.id);
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
for (const auto& id_task : waiting_task_ids) {
|
||||
// LOG_TEE("waiting task id %i \n", id_task);
|
||||
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
|
||||
if (result.id_multi == id_task) {
|
||||
//LOG_VERBOSE("callback_update_multitask", { {"id_task", id_task} });
|
||||
QUE_DBG("callback_update_multitask, id = %d\n", id_task);
|
||||
callback_update_multitask(id_task, result.id, result);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (result.id == id_task) {
|
||||
//LOG_VERBOSE("queue_results.push_back", { {"id_task", id_task} });
|
||||
QUE_DBG("queue_results.push_back, id = %d\n", id_task);
|
||||
queue_results.push_back(result);
|
||||
condition_results.notify_all();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
117
examples/server/server-queue.h
Normal file
@@ -0,0 +1,117 @@
|
||||
#pragma once
|
||||
#include "server-task.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
|
||||
struct server_task_multi {
|
||||
int id = -1;
|
||||
|
||||
std::set<int> subtasks_remaining;
|
||||
std::vector<server_task_result> results;
|
||||
};
|
||||
|
||||
|
||||
struct server_queue {
|
||||
int id = 0;
|
||||
bool running;
|
||||
|
||||
// queues
|
||||
std::vector<server_task> queue_tasks;
|
||||
std::vector<server_task> queue_tasks_deferred;
|
||||
|
||||
std::vector<server_task_multi> queue_multitasks;
|
||||
|
||||
std::mutex mutex_tasks;
|
||||
std::condition_variable condition_tasks;
|
||||
|
||||
// callback functions
|
||||
std::function<void(server_task &&)> callback_new_task;
|
||||
std::function<void(server_task_multi &)> callback_finish_multitask;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
|
||||
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task task);
|
||||
|
||||
// Add a new task, but defer until one slot is available
|
||||
void defer(server_task&& task);
|
||||
|
||||
// Get the next id for creating a new task
|
||||
int get_new_id();
|
||||
|
||||
// Register function to process a new task
|
||||
void on_new_task(std::function<void(server_task&&)> callback);
|
||||
|
||||
// Register function to process a multitask when it is finished
|
||||
void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
|
||||
callback_finish_multitask = std::move(callback);
|
||||
}
|
||||
|
||||
// Register the function to be called when all slots data is ready to be processed
|
||||
void on_update_slots(std::function<void(void)> callback) {
|
||||
callback_update_slots = std::move(callback);
|
||||
}
|
||||
|
||||
// Call when the state of one slot is changed
|
||||
void notify_slot_changed();
|
||||
|
||||
// end the start_loop routine
|
||||
void terminate() {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
running = false;
|
||||
condition_tasks.notify_all();
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop consists of these steps:
|
||||
* - Wait until a new task arrives
|
||||
* - Process the task (i.e. maybe copy data into slot)
|
||||
* - Check if multitask is finished
|
||||
* - Update all slots
|
||||
*/
|
||||
void start_loop();
|
||||
|
||||
//
|
||||
// functions to manage multitasks
|
||||
//
|
||||
|
||||
// add a multitask by specifying the ids of all subtasks (a subtask is a server_task)
|
||||
void add_multitask(int id_multi, std::vector<int>& sub_ids);
|
||||
|
||||
// update the remaining subtasks, while appending results to the multitask
|
||||
void update_multitask(int id_multi, int id_sub, server_task_result& result);
|
||||
};
|
||||
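A minimal wiring sketch of the queue; in the real server these callbacks are bound to the server_context methods declared in server-context.h (process_single_task, on_finish_multitask, update_slots), shown here only as placeholders:

server_queue queue;
queue.on_new_task([&](server_task && task)           { (void) task; /* ctx_server.process_single_task(std::move(task)); */ });
queue.on_finish_multitask([&](server_task_multi & m) { (void) m;    /* ctx_server.on_finish_multitask(m); */ });
queue.on_update_slots([&]()                          {              /* ctx_server.update_slots(); */ });
queue.start_loop(); // blocks until terminate() is called from another thread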
|
||||
struct server_response {
|
||||
typedef std::function<void(int, int, server_task_result&)> callback_multitask_t;
|
||||
callback_multitask_t callback_update_multitask;
|
||||
|
||||
// for keeping track of all tasks waiting for the result
|
||||
std::set<int> waiting_task_ids;
|
||||
|
||||
// the main result queue
|
||||
std::vector<server_task_result> queue_results;
|
||||
|
||||
std::mutex mutex_results;
|
||||
std::condition_variable condition_results;
|
||||
|
||||
// add the id_task to the list of tasks waiting for response
|
||||
void add_waiting_task_id(int id_task);
|
||||
|
||||
// when the request is finished, we can remove task associated with it
|
||||
void remove_waiting_task_id(int id_task);
|
||||
|
||||
// This function blocks the thread until there is a response for this id_task
|
||||
server_task_result recv(int id_task);
|
||||
|
||||
// Register the function to update multitask
|
||||
void on_multitask_update(callback_multitask_t callback) {
|
||||
callback_update_multitask = std::move(callback);
|
||||
}
|
||||
|
||||
// Send a new result to a waiting id_task
|
||||
void send(server_task_result result);
|
||||
};
|
||||
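A hedged sketch of the request/response handshake an HTTP handler performs with these two structs; the names ctx_server, data and inputs are assumed from server-context.h and are not shown here:

const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task);
ctx_server.request_completion(id_task, /*id_multi=*/-1, data, /*infill=*/false, /*embedding=*/false, std::move(inputs));
server_task_result result = ctx_server.queue_results.recv(id_task); // blocks until a slot answers
ctx_server.queue_results.remove_waiting_task_id(id_task);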
816
examples/server/server-task.cpp
Normal file
@@ -0,0 +1,816 @@
|
||||
#include "server-task.h"
|
||||
|
||||
|
||||
json result_timings::to_json() const {
|
||||
json base = {
|
||||
{"prompt_n", prompt_n},
|
||||
{"prompt_ms", prompt_ms},
|
||||
{"prompt_per_token_ms", prompt_per_token_ms},
|
||||
{"prompt_per_second", prompt_per_second},
|
||||
|
||||
{"predicted_n", predicted_n},
|
||||
{"predicted_ms", predicted_ms},
|
||||
{"predicted_per_token_ms", predicted_per_token_ms},
|
||||
{"predicted_per_second", predicted_per_second},
|
||||
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_past", n_past},
|
||||
};
|
||||
|
||||
if (draft_n > 0) {
|
||||
base["draft_n"] = draft_n;
|
||||
base["draft_n_accepted"] = draft_n_accepted;
|
||||
}
|
||||
|
||||
return base;
|
||||
}
|
||||
|
||||
|
||||
json server_task_result::to_json_final() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_final();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic_final();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
|
||||
json server_task_result::to_json_partial() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat_partial();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return to_json_oaicompat_chat_partial();
|
||||
case OAICOMPAT_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic_partial();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
|
||||
json server_task_result::to_json_non_oaicompat_partial() {
|
||||
// non-OAI-compat JSON
|
||||
json res = json{
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"tokens", tokens},
|
||||
{"stop", false},
|
||||
{"id_slot", id_multi},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
};
|
||||
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
||||
if (timings.prompt_n > 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_non_oaicompat_final() {
|
||||
json res = json{
|
||||
{"index", index},
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"tokens", stream ? std::vector<llama_token> {} : tokens},
|
||||
{"id_slot", id_multi},
|
||||
{"stop", true},
|
||||
{"model", oaicompat_model},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
//{"generation_settings", default_generation_settings_for_props.to_json()},
|
||||
{"prompt", prompt},
|
||||
{"has_new_line", has_new_line},
|
||||
{"truncated", truncated},
|
||||
//{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
|
||||
{"stopping_word", stopping_word},
|
||||
{"tokens_cached", n_tokens_cached},
|
||||
{"timings", timings.to_json()},
|
||||
};
|
||||
if (!stream && !probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_partial() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (probs_output.size() > 0) {
|
||||
logprobs = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
json res = json{
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", content},
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", nullptr},
|
||||
}
|
||||
})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "text_completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat_partial();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_final() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
logprobs = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
json finish_reason = "length";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
json res = json{
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", finish_reason},
|
||||
}
|
||||
})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "text_completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_partial() {
|
||||
bool first = n_decoded == 1;
|
||||
std::time_t t = std::time(0);
|
||||
json choices;
|
||||
|
||||
std::vector<json> deltas;
|
||||
auto add_delta = [&](const json& delta) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", delta},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
});
|
||||
};
|
||||
// We have to send an initial update to conform to OpenAI behavior
|
||||
if (first) {
|
||||
add_delta({
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& diff : oaicompat_msg_diffs) {
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
||||
}
|
||||
|
||||
if (!deltas.empty()) {
|
||||
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
|
||||
|
||||
if (probs_output.size() > 0) {
|
||||
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas[deltas.size() - 1].push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
}
|
||||
|
||||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_final() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
}
|
||||
else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
if (stop) {
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
|
||||
json choice{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", msg.to_json_oaicompat<json>()},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
choice["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json res = json{
|
||||
{"choices", json::array({choice})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_oaicompat_chat_stream() {
|
||||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
if (stop) {
|
||||
//if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
json deltas = json::array();
|
||||
for (const auto& diff : oaicompat_msg_diffs) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
}
|
||||
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
if (include_usage) {
|
||||
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
||||
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
||||
deltas.push_back({
|
||||
{"choices", json::array()},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
});
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas.back().push_back({ "timings", timings.to_json() });
|
||||
}
|
||||
// extra fields for debugging purposes
|
||||
if (verbose && !deltas.empty()) {
|
||||
deltas.front()["__verbose"] = to_json_non_oaicompat_final();
|
||||
}
|
||||
|
||||
return deltas;
|
||||
}
|
||||
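Roughly, the tail of the stream built above is one chunk whose delta is empty and whose finish_reason is set, optionally followed by a usage-only chunk with an empty choices array; illustrative payloads with made-up values:

// data: {"choices":[{"finish_reason":"stop","index":0,"delta":{}}], "created":1700000000,
//        "id":"chatcmpl-...", "model":"gpt-3.5-turbo-0613", "object":"chat.completion.chunk"}
// data: {"choices":[], "created":1700000000, "id":"chatcmpl-...", "model":"gpt-3.5-turbo-0613",
//        "object":"chat.completion.chunk",
//        "usage":{"completion_tokens":7,"prompt_tokens":12,"total_tokens":19}}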
|
||||
json server_task_result::to_json_anthropic_final() {
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
json content_blocks = json::array();
|
||||
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
}
|
||||
else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
|
||||
if (!msg.content.empty()) {
|
||||
content_blocks.push_back({
|
||||
{"type", "text"},
|
||||
{"text", msg.content}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& tool_call : msg.tool_calls) {
|
||||
json tool_use_block = {
|
||||
{"type", "tool_use"},
|
||||
{"id", tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
|
||||
try {
|
||||
tool_use_block["input"] = json::parse(tool_call.arguments);
|
||||
}
|
||||
catch (const std::exception&) {
|
||||
tool_use_block["input"] = json::object();
|
||||
}
|
||||
|
||||
content_blocks.push_back(tool_use_block);
|
||||
}
|
||||
|
||||
json res = {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", content_blocks},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result::to_json_anthropic_stream() {
    json events = json::array();

    std::string stop_reason = "max_tokens";
    if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
        stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
    }

    bool has_text = !oaicompat_msg.content.empty();
    size_t num_tool_calls = oaicompat_msg.tool_calls.size();

    bool text_block_started = false;
    std::set<size_t> tool_calls_started;

    for (const auto& diff : oaicompat_msg_diffs) {
        if (!diff.content_delta.empty()) {
            if (!text_block_started) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", 0},
                        {"content_block", {
                            {"type", "text"},
                            {"text", ""}
                        }}
                    }}
                });
                text_block_started = true;
            }

            events.push_back({
                {"event", "content_block_delta"},
                {"data", {
                    {"type", "content_block_delta"},
                    {"index", 0},
                    {"delta", {
                        {"type", "text_delta"},
                        {"text", diff.content_delta}
                    }}
                }}
            });
        }

        if (diff.tool_call_index != std::string::npos) {
            size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;

            if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
                const auto& full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];

                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", content_block_index},
                        {"content_block", {
                            {"type", "tool_use"},
                            {"id", full_tool_call.id},
                            {"name", full_tool_call.name}
                        }}
                    }}
                });
                tool_calls_started.insert(diff.tool_call_index);
            }

            if (!diff.tool_call_delta.arguments.empty()) {
                events.push_back({
                    {"event", "content_block_delta"},
                    {"data", {
                        {"type", "content_block_delta"},
                        {"index", content_block_index},
                        {"delta", {
                            {"type", "input_json_delta"},
                            {"partial_json", diff.tool_call_delta.arguments}
                        }}
                    }}
                });
            }
        }
    }

    if (has_text) {
        events.push_back({
            {"event", "content_block_stop"},
            {"data", {
                {"type", "content_block_stop"},
                {"index", 0}
            }}
        });
    }

    for (size_t i = 0; i < num_tool_calls; i++) {
        size_t content_block_index = (has_text ? 1 : 0) + i;
        events.push_back({
            {"event", "content_block_stop"},
            {"data", {
                {"type", "content_block_stop"},
                {"index", content_block_index}
            }}
        });
    }

    events.push_back({
        {"event", "message_delta"},
        {"data", {
            {"type", "message_delta"},
            {"delta", {
                {"stop_reason", stop_reason},
                {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
            }},
            {"usage", {
                {"output_tokens", n_decoded}
            }}
        }}
    });

    events.push_back({
        {"event", "message_stop"},
        {"data", {
            {"type", "message_stop"}
        }}
    });

    // extra fields for debugging purposes
    if (verbose && !events.empty()) {
        events.front()["data"]["__verbose"] = to_json_non_oaicompat_final();
    }
    // Don't add timings for Anthropic API (breaks spec compliance)
    if (oaicompat != OAICOMPAT_TYPE_ANTHROPIC && timings.prompt_n >= 0 && !events.empty()) {
        events.back()["data"]["timings"] = timings.to_json();
    }

    return events;
}
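The stream variant returns an array of {"event", "data"} pairs rather than writing to the socket itself. A minimal sketch of how a caller might flush them as server-sent events follows; the helper name and the fact that the HTTP layer consumes the array this way are assumptions, not part of this commit (only the json alias from above is reused).

// Sketch: serialize the events produced by to_json_anthropic_stream() as SSE frames.
// to_sse_frames() is a hypothetical helper, not part of this diff.
static std::string to_sse_frames(const json & events) {
    std::string out;
    for (const auto & ev : events) {
        out += "event: " + ev.at("event").get<std::string>() + "\n";
        out += "data: "  + ev.at("data").dump() + "\n\n";
    }
    return out;
}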
json server_task_result::to_json_anthropic_partial() {
    json events = json::array();
    bool first = n_decoded == 1;
    static bool text_block_started = false;

    if (first) {
        text_block_started = false;

        events.push_back({
            {"event", "message_start"},
            {"data", {
                {"type", "message_start"},
                {"message", {
                    {"id", oaicompat_cmpl_id},
                    {"type", "message"},
                    {"role", "assistant"},
                    {"content", json::array()},
                    {"model", oaicompat_model},
                    {"stop_reason", nullptr},
                    {"stop_sequence", nullptr},
                    {"usage", {
                        {"input_tokens", n_prompt_tokens},
                        {"output_tokens", 0}
                    }}
                }}
            }}
        });
    }

    for (const auto& diff : oaicompat_msg_diffs) {
        if (!diff.content_delta.empty()) {
            if (!text_block_started) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", 0},
                        {"content_block", {
                            {"type", "text"},
                            {"text", ""}
                        }}
                    }}
                });
                text_block_started = true;
            }

            events.push_back({
                {"event", "content_block_delta"},
                {"data", {
                    {"type", "content_block_delta"},
                    {"index", 0},
                    {"delta", {
                        {"type", "text_delta"},
                        {"text", diff.content_delta}
                    }}
                }}
            });
        }

        if (diff.tool_call_index != std::string::npos) {
            size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;

            if (!diff.tool_call_delta.name.empty()) {
                events.push_back({
                    {"event", "content_block_start"},
                    {"data", {
                        {"type", "content_block_start"},
                        {"index", content_block_index},
                        {"content_block", {
                            {"type", "tool_use"},
                            {"id", diff.tool_call_delta.id},
                            {"name", diff.tool_call_delta.name}
                        }}
                    }}
                });
            }

            if (!diff.tool_call_delta.arguments.empty()) {
                events.push_back({
                    {"event", "content_block_delta"},
                    {"data", {
                        {"type", "content_block_delta"},
                        {"index", content_block_index},
                        {"delta", {
                            {"type", "input_json_delta"},
                            {"partial_json", diff.tool_call_delta.arguments}
                        }}
                    }}
                });
            }
        }
    }

    if (verbose && !events.empty() && first) {
        events.front()["data"]["__verbose"] = to_json_non_oaicompat_partial();
    }

    if (timings.prompt_n >= 0 && !events.empty()) {
        events.back()["data"]["timings"] = timings.to_json();
    }

    //if (is_progress && !events.empty()) {
    //    events.back()["data"]["prompt_progress"] = progress.to_json();
    //}

    return events;
}
size_t server_prompt::size() const {
    size_t res = data.size();

    for (const auto& checkpoint : checkpoints) {
        res += checkpoint.size();
    }

    return res;
}

size_t server_prompt_cache::size() const {
    size_t res = 0;

    for (const auto& state : states) {
        res += state.size();
    }

    return res;
}

size_t server_prompt_cache::n_tokens() const {
    size_t res = 0;

    for (const auto& state : states) {
        res += state.n_tokens();
    }

    return res;
}

bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
    const auto lcp_best = prompt.tokens.get_common_prefix(ctx, tokens_new);

    float f_keep_best = float(lcp_best.second) / prompt.tokens.size();
    float sim_best    = prompt.tokens.get_tokens_similarity(ctx, tokens_new, prompt.n_kept_prompt, prompt.n_discarded_prompt);
    LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);

    auto it_best = states.end();

    // find the most similar cached prompt that also preserves the most context
    for (auto it = states.begin(); it != states.end(); ++it) {
        const auto lcp_cur = it->tokens.get_common_prefix(ctx, tokens_new);
        const float f_keep_cur = float(lcp_cur.first) / it->tokens.size();
        const float sim_cur    = it->tokens.get_tokens_similarity(ctx, tokens_new, it->n_kept_prompt, it->n_discarded_prompt);
        if (sim_best < sim_cur) {
            f_keep_best = f_keep_cur;
            sim_best    = sim_cur;
            it_best     = it;
        }
    }

    if (it_best != states.end()) {
        LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
        const size_t size = it_best->data.size();
        const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot);
        if (n != size) {
            LLAMA_LOG_INFO("failed to restore state with size %zu\n", size);
            return false;
        }

        it_best->data.clear();
        it_best->data.shrink_to_fit();

        prompt = std::move(*it_best);

        states.erase(it_best);
    }

    return true;
}

server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t state_size) {
    for (auto it = states.begin(); it != states.end();) {
        auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens
        tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt);
        auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift);
        const size_t len        = prefix.first;
        const size_t len_prompt = prefix.second;
        // first check if the current state is contained fully in the cache
        if (len_prompt == tokens_ctx_shift.size()) {
            LLAMA_LOG_INFO("%s", " - prompt is already in the cache, skipping\n");
            return nullptr;
        }
        // next, remove any cached prompts that are fully contained in the current prompt
        else if (len == it->tokens.size()) {
            LLAMA_LOG_INFO(" - removing obsolete cached prompt with length %d\n", (int)len);
            it = states.erase(it);
        }
        else {
            ++it;
        }
    }

    std::vector<uint8_t> state_data;

    // check if we can allocate enough memory for the new state
    try {
        state_data.resize(state_size);
    }
    catch (const std::bad_alloc& e) {
        LLAMA_LOG_INFO("failed to allocate memory for prompt cache state: %s\n", e.what());

        limit_size = std::max<size_t>(1, 0.4 * size());

        LLAMA_LOG_INFO(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));

        update();

        return nullptr;
    }

    // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
    auto& cur = states.emplace_back();
    cur = {
        /*.tokens =*/             server_tokens(prompt.tokens.get_text_tokens(), false),
        /*.n_kept_prompt =*/      prompt.n_kept_prompt,
        /*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
        /*.data =*/               std::move(state_data),
        /*.checkpoints =*/        prompt.checkpoints,
    };

    return &cur;
}

void server_prompt_cache::update() {
    if (limit_size > 0) {
        // always keep at least one state, regardless of the limits
        while (states.size() > 1 && size() > limit_size) {
            if (states.empty()) {
                break;
            }

            LLAMA_LOG_INFO(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));

            states.pop_front();
        }
    }

    // average size per token
    const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));

    // dynamically increase the token limit if it can fit in the memory limit
    const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size / size_per_token) : limit_tokens;

    LLAMA_LOG_INFO(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
            states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);

    for (const auto& state : states) {
        LLAMA_LOG_INFO(" - prompt %p: %7d tokens, %7d discarded, checkpoints: %2zu, %9.3f MiB\n",
                (const void*)&state, state.n_tokens(), state.n_discarded_prompt, state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
    }
}
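Taken together, load(), alloc() and update() form the cache lifecycle. The sketch below shows one plausible order of operations at a call site; the wrapper function and its names are assumptions for illustration, while llama_state_seq_get_size()/llama_state_seq_get_data() are the existing llama.cpp state API used as the counterpart of the llama_state_seq_set_data() call above.

// Hypothetical call site - illustrates the intended order of operations only.
void save_and_reuse_prompt(server_prompt_cache & cache, server_prompt & slot_prompt,
                           const server_tokens & incoming, llama_context * ctx, int32_t id_slot) {
    // 1. snapshot the current slot state into the cache before switching prompts
    const size_t state_size = llama_state_seq_get_size(ctx, id_slot);
    if (server_prompt * cached = cache.alloc(slot_prompt, state_size)) {
        llama_state_seq_get_data(ctx, cached->data.data(), cached->data.size(), id_slot);
    }
    cache.update(); // enforce the size/token limits

    // 2. try to restore the most similar cached state for the new request
    cache.load(slot_prompt, incoming, ctx, id_slot);
}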
216
examples/server/server-task.h
Normal file
@@ -0,0 +1,216 @@
#pragma once

#include "common.h"
#include "llama.h"

#include <string>
#include <unordered_set>
#include <list>
// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"

using json = nlohmann::ordered_json;

enum stop_type {
    STOP_TYPE_NONE,
    STOP_TYPE_EOS,
    STOP_TYPE_WORD,
    STOP_TYPE_LIMIT,
};

enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,
    SERVER_TASK_TYPE_EMBEDDING,
    SERVER_TASK_TYPE_RERANK,
    SERVER_TASK_TYPE_INFILL,
    SERVER_TASK_TYPE_CANCEL,
    SERVER_TASK_TYPE_NEXT_RESPONSE,
    SERVER_TASK_TYPE_METRICS,
    SERVER_TASK_TYPE_SLOT_SAVE,
    SERVER_TASK_TYPE_SLOT_RESTORE,
    SERVER_TASK_TYPE_SLOT_ERASE,
    SERVER_TASK_TYPE_SET_LORA,
};

enum oaicompat_type {
    OAICOMPAT_TYPE_NONE,
    OAICOMPAT_TYPE_CHAT,
    OAICOMPAT_TYPE_COMPLETION,
    OAICOMPAT_TYPE_EMBEDDING,
    OAICOMPAT_TYPE_ANTHROPIC,
};

struct server_task {
    int id        = -1; // to be filled by server_queue
    int id_multi  = -1;
    int id_target = -1;
    //int id_slot = -1;

    // used by SERVER_TASK_TYPE_INFERENCE
    server_tokens tokens;

    server_task_type type;
    json data;

    bool infill    = false;
    bool embedding = false;

    server_task() = default;
    server_task(server_task_type type) : type(type) {}
};

struct result_timings {
    int32_t prompt_n = -1;
    double prompt_ms;
    double prompt_per_token_ms;
    double prompt_per_second;

    int32_t predicted_n = -1;
    double predicted_ms;
    double predicted_per_token_ms;
    double predicted_per_second;
    int32_t n_ctx  = 0;
    int32_t n_past = 0;

    // Optional speculative metrics - only included when > 0
    int32_t draft_n          = 0;
    int32_t draft_n_accepted = 0;

    json to_json() const;
};

struct server_task_result {
    int id       = -1;
    int id_multi = -1;

    json data;

    bool stop;
    bool error;
    bool final_result = false;
    result_timings timings;

    // OAI-compat fields
    //bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
    common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    common_chat_msg oaicompat_msg;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

    int index = 0;

    std::string content;
    std::vector<llama_token> tokens;

    bool stream;
    bool include_usage;
    std::string prompt;
    //slot_params generation_params;

    bool truncated;
    int32_t n_decoded;
    int32_t n_prompt_tokens;
    int32_t n_tokens_cached;
    bool has_new_line;
    std::string stopping_word;

    bool post_sampling_probs = false;
    std::vector<completion_token_output> probs_output;
    std::vector<std::string> response_fields;

    //slot_params generation_params;

    bool verbose = false;

    int get_index() {
        return index;
    }

    bool is_stop() {
        return true; // in stream mode, final responses are considered stop
    }

    json to_json_final();
    json to_json_partial();
    json to_json_non_oaicompat_partial();
    json to_json_non_oaicompat_final();
    json to_json_oaicompat_partial();
    json to_json_oaicompat_final();
    json to_json_oaicompat_chat_partial();
    json to_json_oaicompat_chat_final();
    json to_json_oaicompat_chat_stream();
    json to_json_anthropic_final();
    json to_json_anthropic_stream();
    json to_json_anthropic_partial();
};
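The result object carries both the OpenAI-compatible and the Anthropic serializers. One plausible dispatch based on the oaicompat, stream and final_result fields is sketched below; the helper is illustrative and not part of this header.

// Hypothetical helper showing how the to_json_* family might be selected.
inline json result_to_json(server_task_result & res) {
    if (res.oaicompat == OAICOMPAT_TYPE_ANTHROPIC) {
        if (!res.stream) {
            return res.to_json_anthropic_final();
        }
        return res.final_result ? res.to_json_anthropic_stream()
                                : res.to_json_anthropic_partial();
    }
    return res.final_result ? res.to_json_final() : res.to_json_partial();
}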
struct server_prompt_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;

    size_t size() const {
        return data.size();
    }
};

struct server_prompt {
    server_tokens tokens;
    int n_kept_prompt;
    int n_discarded_prompt;

    std::vector<uint8_t> data;

    std::list<server_prompt_checkpoint> checkpoints;

    size_t size() const;

    int n_tokens() const {
        return tokens.size();
    }
};

struct server_prompt_cache {
    server_prompt_cache(llama_context* ctx, int32_t limit_size_mib, size_t limit_tokens) {
        this->ctx          = ctx;
        this->limit_size   = 1024ull * 1024ull * (limit_size_mib < 0 ? 0 : limit_size_mib);
        this->limit_tokens = limit_tokens;
    }

    std::list<server_prompt> states;

    // in bytes, 0 = no limit
    size_t limit_size = 0;

    // in tokens, 0 = no limit
    size_t limit_tokens = 0;

    llama_context* ctx;

    size_t size() const;

    size_t n_tokens() const;

    server_prompt* alloc(const server_prompt& prompt, size_t state_size);

    bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot);

    void update();
};
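The constructor takes the memory limit in MiB and the token limit directly. A minimal construction sketch follows; the limit values are placeholders and not defaults from this commit, and the context is assumed to be created elsewhere.

// Hypothetical setup - the numbers are placeholders chosen for illustration.
static server_prompt_cache make_prompt_cache(llama_context * ctx) {
    // cap the cache at 2048 MiB of state data and roughly 256k cached tokens
    return server_prompt_cache(ctx, /*limit_size_mib=*/2048, /*limit_tokens=*/256 * 1024);
}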
File diff suppressed because it is too large
Load Diff
@@ -773,7 +773,6 @@ void launch_fattn(
    size_t nb11 = K->nb[1];
    size_t nb12 = K->nb[2];
    size_t nb13 = K->nb[3];

    char * V_data = (char *) V->data;
    size_t nb21 = V->nb[1];
    size_t nb22 = V->nb[2];