mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-23 16:09:18 +00:00
Tool calls support from mainline (#723)
* Tool calls support from mainline * update cmake * revert api for /completions * Fix broken thinking process for gpt-oss * add missing args and fix webui bugs * add missing args and fix webui bugs2 * Fix reasoning format error * add usage * change default post_sampling_probs to true * add back generated_text * Remove server endpoints tests * add log * Chat fixes * Remove logs * webui: revert extra handling of thinking process --------- Co-authored-by: firecoperana <firecoperana> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -52,17 +52,13 @@ set(TARGET common)
|
||||
|
||||
add_library(${TARGET} STATIC
|
||||
base64.hpp
|
||||
chat-template.hpp
|
||||
common.h
|
||||
chat.cpp
|
||||
chat.h
|
||||
chat-parser.cpp
|
||||
chat-parser.h
|
||||
common.cpp
|
||||
chat.h
|
||||
chat.cpp
|
||||
chat-parser.h
|
||||
chat-parser.cpp
|
||||
json-partial.h
|
||||
json-partial.cpp
|
||||
regex-partial.h
|
||||
regex-partial.cpp
|
||||
sampling.h
|
||||
sampling.cpp
|
||||
console.h
|
||||
@@ -70,13 +66,19 @@ add_library(${TARGET} STATIC
|
||||
grammar-parser.h
|
||||
grammar-parser.cpp
|
||||
json.hpp
|
||||
json-partial.h
|
||||
json-partial.cpp
|
||||
llguidance.cpp
|
||||
json-schema-to-grammar.cpp
|
||||
train.h
|
||||
train.cpp
|
||||
minja.hpp
|
||||
minja/chat-template.hpp
|
||||
minja/minja.hpp
|
||||
ngram-cache.h
|
||||
ngram-cache.cpp
|
||||
speculative.cpp
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
@@ -94,6 +96,33 @@ if (LLAMA_CURL)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
|
||||
endif ()
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.6.12:
|
||||
GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
|
||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||
BUILD_IN_SOURCE TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND cargo build --release
|
||||
INSTALL_COMMAND ""
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
|
||||
UPDATE_COMMAND ""
|
||||
)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
|
||||
|
||||
add_library(llguidance STATIC IMPORTED)
|
||||
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
|
||||
endif ()
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC .)
|
||||
target_compile_features (${TARGET} PUBLIC cxx_std_11)
|
||||
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
||||
|
||||
@@ -1,54 +1,69 @@
|
||||
// Chat parser implementation
|
||||
#include "chat-parser.h"
|
||||
#include "../examples/server/parsers/kimi_k2_parser.hpp"
|
||||
#include "json.hpp"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax) {
|
||||
// Initialize result with default role
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
result_.role = "assistant";
|
||||
|
||||
while (true) {
|
||||
std::string id = std::to_string(std::rand());
|
||||
if (input.find(id) == std::string::npos) {
|
||||
healing_marker_ = id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
|
||||
if (rng.begin > input_.size() || rng.end > input_.size()) {
|
||||
throw std::runtime_error("Range out of bounds");
|
||||
}
|
||||
GGML_ASSERT(rng.begin <= rng.end);
|
||||
return input_.substr(rng.begin, rng.end - rng.begin);
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_content(const std::string & content) {
|
||||
void common_chat_msg_parser::add_content(const std::string &content) {
|
||||
result_.content += content;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_reasoning_content(const std::string & reasoning_content) {
|
||||
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
|
||||
result_.reasoning_content += reasoning_content;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_tool_call(const common_chat_tool_call & tool_call) {
|
||||
result_.tool_calls.push_back(tool_call);
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
|
||||
if (name.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
common_chat_tool_call tool_call;
|
||||
tool_call.name = name;
|
||||
tool_call.arguments = arguments;
|
||||
tool_call.id = id;
|
||||
|
||||
|
||||
// LOG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
|
||||
result_.tool_calls.emplace_back(tool_call);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
|
||||
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
|
||||
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
|
||||
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
|
||||
std::string arguments = "";
|
||||
if (tool_call.contains("arguments")) {
|
||||
if (tool_call.at("arguments").is_object()) {
|
||||
arguments = tool_call.at("arguments").dump();
|
||||
} else {
|
||||
arguments = tool_call.at("arguments");
|
||||
}
|
||||
}
|
||||
|
||||
return add_tool_call(name, id, arguments);
|
||||
}
|
||||
|
||||
@@ -60,25 +75,65 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::clear_tools() {
|
||||
result_.tool_calls.clear();
|
||||
void common_chat_msg_parser::finish() {
|
||||
if (!is_partial_ && pos_ != input_.size()) {
|
||||
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::consume_rest() {
|
||||
auto rest = input_.substr(pos_);
|
||||
pos_ = input_.size();
|
||||
return rest;
|
||||
bool common_chat_msg_parser::consume_spaces() {
|
||||
const auto length = input_.size();
|
||||
auto consumed = false;
|
||||
while (pos_ < length && std::isspace(input_[pos_])) {
|
||||
++pos_;
|
||||
consumed = true;
|
||||
}
|
||||
return consumed;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
|
||||
if (pos_ + literal.size() <= input_.size()) {
|
||||
if (input_.substr(pos_, literal.size()) == literal) {
|
||||
pos_ += literal.size();
|
||||
return true;
|
||||
auto pos = pos_;
|
||||
for (auto i = 0u; i < literal.size(); ++i) {
|
||||
if (pos >= input_.size()) {
|
||||
return false;
|
||||
}
|
||||
if (input_[pos] != literal[i]) {
|
||||
return false;
|
||||
}
|
||||
++pos;
|
||||
}
|
||||
pos_ = pos;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
|
||||
auto idx = input_.find(literal, pos_);
|
||||
if (idx != std::string::npos) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = idx + literal.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
if (is_partial_) {
|
||||
idx = string_find_partial_stop(input_, literal);
|
||||
if (idx != std::string::npos && idx >= pos_) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = input_.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
||||
if (!try_consume_literal(literal)) {
|
||||
throw common_chat_msg_partial_exception(literal);
|
||||
}
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
|
||||
@@ -97,7 +152,6 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
|
||||
add_reasoning_content(stripped_reasoning);
|
||||
}
|
||||
};
|
||||
|
||||
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
|
||||
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
|
||||
if (auto res = try_find_literal(end_think)) {
|
||||
@@ -109,198 +163,73 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
|
||||
if (!rest.empty()) {
|
||||
handle_reasoning(rest, /* closed */ !is_partial());
|
||||
}
|
||||
// Allow unclosed thinking tags for now (following original llama.cpp)
|
||||
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
|
||||
// if (!syntax_.thinking_forced_open) {
|
||||
// throw common_chat_msg_partial_exception(end_think);
|
||||
// }
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal_legacy(const std::string & literal) {
|
||||
auto idx = input_.find(literal, pos_);
|
||||
if (idx != std::string::npos) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = idx + literal.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
|
||||
if (is_partial_) {
|
||||
idx = string_find_partial_stop(input_, literal);
|
||||
if (idx != std::string::npos && idx >= pos_) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = input_.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::parse() {
|
||||
switch (syntax_.format) {
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
parse_kimi_k2_format();
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
||||
parse_deepseek_r1_format();
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GENERIC:
|
||||
parse_generic_format();
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
||||
add_content(consume_rest());
|
||||
break;
|
||||
default:
|
||||
// Fallback to content-only for now
|
||||
add_content(consume_rest());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::parse_kimi_k2_format() {
|
||||
json tool_calls_json = kimi_k2::parse_tool_calls(input_);
|
||||
|
||||
if (is_partial_ && kimi_k2::is_partial_content_advanced(input_)) {
|
||||
throw common_chat_msg_partial_exception("partial structured content detected");
|
||||
}
|
||||
|
||||
bool has_function_syntax = input_.find("functions.") != std::string::npos;
|
||||
bool parsing_succeeded = !tool_calls_json.empty();
|
||||
|
||||
if (has_function_syntax && !parsing_succeeded) {
|
||||
throw std::runtime_error("malformed function call syntax detected");
|
||||
}
|
||||
|
||||
if (!tool_calls_json.empty()) {
|
||||
for (const auto& tc_json : tool_calls_json) {
|
||||
try {
|
||||
common_chat_tool_call tc;
|
||||
tc.id = tc_json.value("id", "");
|
||||
|
||||
if (!tc_json.contains("function") || !tc_json["function"].contains("name")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tc.name = tc_json["function"]["name"];
|
||||
if (tc.name.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tc.arguments = tc_json["function"]["arguments"];
|
||||
|
||||
if (!is_partial_ && !tc.arguments.empty()) {
|
||||
try {
|
||||
auto parsed = json::parse(tc.arguments);
|
||||
(void)parsed;
|
||||
} catch (const std::exception&) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
add_tool_call(tc);
|
||||
} catch (const std::exception&) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
add_content(kimi_k2::clean_content(input_));
|
||||
} else {
|
||||
add_content(input_);
|
||||
}
|
||||
std::string common_chat_msg_parser::consume_rest() {
|
||||
auto rest = input_.substr(pos_);
|
||||
pos_ = input_.size();
|
||||
return rest;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::parse_generic_format() {
|
||||
add_content(consume_rest());
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::parse_deepseek_r1_format() {
|
||||
// Delegate to the main chat.cpp function which has the corrected implementation
|
||||
// This follows the original llama.cpp pattern where chat-parser delegates to chat.cpp
|
||||
common_chat_parse_deepseek_r1(*this);
|
||||
}
|
||||
|
||||
|
||||
void common_chat_msg_parser::finish() {
|
||||
// Any final processing can go here
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_msg_parser::result_and_reset() {
|
||||
auto msg = result_;
|
||||
result_ = common_chat_msg();
|
||||
result_.role = "assistant";
|
||||
pos_ = 0;
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Content-only parsing for fallback scenarios
|
||||
|
||||
// Format detection from chat template patterns (focused on DeepSeek R1 and Kimi K2)
|
||||
common_chat_format common_chat_format_detect(const std::string & chat_template) {
|
||||
if (chat_template.empty()) {
|
||||
return COMMON_CHAT_FORMAT_GENERIC;
|
||||
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
|
||||
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Detect DeepSeek R1 format (following original llama.cpp detection logic)
|
||||
if (chat_template.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
||||
return COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
||||
}
|
||||
|
||||
// Detect Kimi K2 format (our custom format)
|
||||
if (chat_template.find("kimi") != std::string::npos ||
|
||||
chat_template.find("Kimi") != std::string::npos ||
|
||||
chat_template.find("functions.") != std::string::npos) {
|
||||
return COMMON_CHAT_FORMAT_KIMI_K2;
|
||||
}
|
||||
|
||||
// Default to generic format for unknown templates
|
||||
return COMMON_CHAT_FORMAT_GENERIC;
|
||||
}
|
||||
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
// Progressive parsing primitive - find literal (following original llama.cpp pattern)
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
|
||||
auto idx = input_.find(literal, pos_);
|
||||
if (idx != std::string::npos) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = idx + literal.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
if (add_prelude_to_content) {
|
||||
add_content(prelude);
|
||||
}
|
||||
|
||||
if (is_partial_) {
|
||||
idx = string_find_partial_stop(input_, literal);
|
||||
if (idx != std::string::npos && idx >= pos_) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = input_.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::nullopt;
|
||||
return find_regex_result{prelude, m.groups};
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::consume_spaces() {
|
||||
bool consumed = false;
|
||||
while (pos_ < input_.length() && std::isspace(input_[pos_])) {
|
||||
pos_++;
|
||||
consumed = true;
|
||||
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
|
||||
if (auto result = try_consume_regex(regex)) {
|
||||
return *result;
|
||||
}
|
||||
return consumed;
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::set_healing_marker(const std::string & marker) {
|
||||
healing_marker_ = marker;
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
|
||||
auto m = regex.search(input_, pos_);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.groups[0].begin != pos_) {
|
||||
// Didn't match at the current position.
|
||||
return std::nullopt;
|
||||
}
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
return find_regex_result {
|
||||
/* .prelude = */ "",
|
||||
m.groups,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
// Enhanced JSON parsing methods (following original llama.cpp patterns exactly)
|
||||
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
|
||||
auto it = input_.cbegin() + pos_;
|
||||
const auto end = input_.cend();
|
||||
@@ -327,8 +256,8 @@ common_json common_chat_msg_parser::consume_json() {
|
||||
}
|
||||
|
||||
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>>& args_paths,
|
||||
const std::vector<std::vector<std::string>>& content_paths
|
||||
const std::vector<std::vector<std::string>> & args_paths,
|
||||
const std::vector<std::vector<std::string>> & content_paths
|
||||
) {
|
||||
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
|
||||
return *result;
|
||||
@@ -337,8 +266,8 @@ common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>>& args_paths,
|
||||
const std::vector<std::vector<std::string>>& content_paths
|
||||
const std::vector<std::vector<std::string>> & args_paths,
|
||||
const std::vector<std::vector<std::string>> & content_paths
|
||||
) {
|
||||
auto partial = try_consume_json();
|
||||
if (!partial) {
|
||||
@@ -366,137 +295,99 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
// TODO: Implement full path-based argument dumping logic from original
|
||||
// For now, return the parsed JSON as-is
|
||||
return consume_json_result {
|
||||
partial->json,
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
// Has healing marker - this is partial JSON
|
||||
// TODO: Implement sophisticated partial JSON handling with path-based dumping
|
||||
// For now, return partial result
|
||||
return consume_json_result {
|
||||
partial->json,
|
||||
/* .is_partial = */ true,
|
||||
};
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::detect_partial_function_call(const std::string& content) {
|
||||
if (content.empty()) return false;
|
||||
|
||||
// Enhanced partial detection patterns
|
||||
static const std::vector<std::string> partial_patterns = {
|
||||
"functions",
|
||||
"functions.",
|
||||
"<tool_call",
|
||||
"<tool_call>",
|
||||
"<invoke",
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_begin|>"
|
||||
};
|
||||
|
||||
for (const auto& pattern : partial_patterns) {
|
||||
if (content.substr(0, pattern.length()) == pattern && content.length() <= pattern.length() + 50) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
LOG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||
|
||||
void common_chat_msg_parser::handle_partial_detection() {
|
||||
if (!is_partial_) return;
|
||||
|
||||
// Check for various partial patterns
|
||||
std::string remaining = input_.substr(pos_);
|
||||
|
||||
if (remaining.empty()) return;
|
||||
|
||||
// Detect partial function calls
|
||||
if (detect_partial_function_call(remaining)) {
|
||||
set_healing_marker(remaining);
|
||||
throw common_chat_msg_partial_exception("partial function call detected");
|
||||
}
|
||||
|
||||
// Enhanced partial JSON detection
|
||||
if (remaining.find('{') != std::string::npos) {
|
||||
size_t brace_pos = remaining.find('{');
|
||||
std::string json_part = remaining.substr(brace_pos);
|
||||
|
||||
// Check if JSON is incomplete
|
||||
int brace_count = 0;
|
||||
bool in_string = false;
|
||||
bool escaped = false;
|
||||
bool is_incomplete = true;
|
||||
|
||||
for (size_t i = 0; i < json_part.length(); i++) {
|
||||
char c = json_part[i];
|
||||
|
||||
if (!escaped) {
|
||||
if (c == '"' && !in_string) {
|
||||
in_string = true;
|
||||
} else if (c == '"' && in_string) {
|
||||
in_string = false;
|
||||
} else if (!in_string) {
|
||||
if (c == '{') brace_count++;
|
||||
else if (c == '}') brace_count--;
|
||||
auto found_healing_marker = false;
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
arguments.resize(idx);
|
||||
found_healing_marker = true;
|
||||
}
|
||||
if (arguments == "\"") {
|
||||
// This happens because of completing `:"$magic` after `"arguments"`
|
||||
arguments = "";
|
||||
}
|
||||
}
|
||||
|
||||
escaped = (!escaped && c == '\\');
|
||||
|
||||
if (brace_count == 0) {
|
||||
is_incomplete = false;
|
||||
break;
|
||||
return arguments;
|
||||
}
|
||||
if (is_content_path(path)) {
|
||||
if (!j.is_string()) {
|
||||
throw std::runtime_error("Content path must be a string");
|
||||
}
|
||||
std::string str = j;
|
||||
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
|
||||
if (idx != std::string::npos) {
|
||||
str.resize(idx);
|
||||
found_healing_marker = true;
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
if (is_incomplete) {
|
||||
set_healing_marker(json_part);
|
||||
throw common_chat_msg_partial_exception("partial JSON detected");
|
||||
if (j.is_object()) {
|
||||
auto obj = json::object();
|
||||
for (const auto & p : j.items()) {
|
||||
const auto & key = p.key();
|
||||
const auto & value = p.value();
|
||||
const std::string key_str = key; // NOLINT
|
||||
auto idx = key_str.find(healing_marker_);
|
||||
if (idx != std::string::npos) {
|
||||
found_healing_marker = true;
|
||||
break;
|
||||
}
|
||||
path.push_back(key_str);
|
||||
if (value.is_string()) {
|
||||
const std::string value_str = value;
|
||||
if (value_str.find(healing_marker_) != std::string::npos) {
|
||||
found_healing_marker = true;
|
||||
if (is_content_path(path)) {
|
||||
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
|
||||
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
|
||||
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
obj[key] = value;
|
||||
} else {
|
||||
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||
}
|
||||
path.pop_back();
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regex-based parsing methods (ported from original llama.cpp)
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
|
||||
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
if (add_prelude_to_content) {
|
||||
add_content(prelude);
|
||||
}
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
if (j.is_array()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & value : j) {
|
||||
if (value.is_string()) {
|
||||
std::string str = value;
|
||||
auto idx = str.find(healing_marker_);
|
||||
if (idx != std::string::npos) {
|
||||
// Don't heal array values that aren't in the arguments.
|
||||
found_healing_marker = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
arr.push_back(remove_unsupported_healings_and_dump_args(value));
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
return find_regex_result{prelude, m.groups};
|
||||
return j;
|
||||
};
|
||||
|
||||
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
|
||||
LOG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||
return consume_json_result {
|
||||
cleaned,
|
||||
/* .is_partial = */ found_healing_marker,
|
||||
};
|
||||
}
|
||||
|
||||
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
|
||||
auto result = try_find_regex(regex);
|
||||
if (!result) {
|
||||
throw std::runtime_error("Expected regex not found: " + regex.str());
|
||||
}
|
||||
return *result;
|
||||
void common_chat_msg_parser::clear_tools() {
|
||||
result_.tool_calls.clear();
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
|
||||
return try_find_regex(regex, pos_, false);
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
||||
if (!try_consume_literal(literal)) {
|
||||
throw std::runtime_error("Expected literal not found: " + literal);
|
||||
}
|
||||
}
|
||||
|
||||
// Get format name for debugging/logging (implemented in chat.cpp)
|
||||
@@ -1,14 +1,18 @@
|
||||
// Chat parser with builder pattern for incremental parsing
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "json-partial.h"
|
||||
#include "json.hpp"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
public:
|
||||
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
@@ -20,14 +24,7 @@ class common_chat_msg_parser {
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
struct find_regex_result {
|
||||
std::string prelude;
|
||||
std::vector<common_string_range> groups;
|
||||
};
|
||||
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
// Accessors
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
@@ -35,14 +32,12 @@ class common_chat_msg_parser {
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
const common_chat_syntax & syntax() const { return syntax_; }
|
||||
|
||||
// Position manipulation
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
throw std::runtime_error("Invalid position!");
|
||||
}
|
||||
pos_ = pos;
|
||||
}
|
||||
|
||||
void move_back(size_t n) {
|
||||
if (pos_ < n) {
|
||||
throw std::runtime_error("Can't move back that far!");
|
||||
@@ -53,84 +48,72 @@ class common_chat_msg_parser {
|
||||
// Get the substring of the input at the given range
|
||||
std::string str(const common_string_range & rng) const;
|
||||
|
||||
// Content manipulation
|
||||
// Appends to the result.content field
|
||||
void add_content(const std::string & content);
|
||||
|
||||
// Appends to the result.reasoning_content field
|
||||
void add_reasoning_content(const std::string & reasoning_content);
|
||||
|
||||
// Tool call manipulation
|
||||
void add_tool_call(const common_chat_tool_call & tool_call);
|
||||
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
||||
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||
bool add_tool_call(const json & tool_call);
|
||||
bool add_tool_calls(const json & arr);
|
||||
void clear_tools();
|
||||
|
||||
// Parsing utilities
|
||||
std::string consume_rest();
|
||||
bool try_consume_literal(const std::string & literal);
|
||||
void consume_literal(const std::string & literal);
|
||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
||||
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
||||
|
||||
// Regex-based parsing methods (new)
|
||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
||||
find_regex_result consume_regex(const common_regex & regex);
|
||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||
|
||||
// Progressive parsing primitives (for Phase 4)
|
||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||
bool consume_spaces();
|
||||
void set_healing_marker(const std::string & marker);
|
||||
|
||||
|
||||
// Main parsing entry point
|
||||
void parse();
|
||||
|
||||
// Finishing
|
||||
void finish();
|
||||
|
||||
// Result extraction
|
||||
common_chat_msg result_and_reset();
|
||||
bool consume_spaces();
|
||||
|
||||
// Advanced JSON parsing (following original llama.cpp patterns)
|
||||
struct consume_json_result {
|
||||
json value;
|
||||
bool is_partial;
|
||||
void consume_literal(const std::string & literal);
|
||||
|
||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||
|
||||
std::string consume_rest();
|
||||
|
||||
struct find_regex_result {
|
||||
std::string prelude;
|
||||
std::vector<common_string_range> groups;
|
||||
};
|
||||
|
||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
||||
|
||||
bool try_consume_literal(const std::string & literal);
|
||||
|
||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||
|
||||
find_regex_result consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<common_json> try_consume_json();
|
||||
common_json consume_json();
|
||||
consume_json_result consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>>& args_paths = {},
|
||||
const std::vector<std::vector<std::string>>& content_paths = {}
|
||||
);
|
||||
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>>& args_paths = {},
|
||||
const std::vector<std::vector<std::string>>& content_paths = {}
|
||||
);
|
||||
|
||||
private:
|
||||
// Internal parsing helpers
|
||||
void parse_kimi_k2_format();
|
||||
void parse_deepseek_r1_format();
|
||||
void parse_generic_format();
|
||||
|
||||
|
||||
// JSON parsing utilities (enhanced streaming support)
|
||||
struct json_parse_result {
|
||||
json value;
|
||||
bool success;
|
||||
struct consume_json_result {
|
||||
nlohmann::ordered_json value;
|
||||
bool is_partial;
|
||||
std::string healing_marker;
|
||||
};
|
||||
|
||||
// Partial detection utilities
|
||||
bool detect_partial_function_call(const std::string& content);
|
||||
void handle_partial_detection();
|
||||
/*
|
||||
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
||||
|
||||
// Legacy find_literal for compatibility
|
||||
std::optional<find_regex_result> try_find_literal_legacy(const std::string & literal);
|
||||
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
||||
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
||||
|
||||
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
||||
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
||||
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
||||
*/
|
||||
consume_json_result consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
|
||||
void clear_tools();
|
||||
};
|
||||
|
||||
// Main parsing function (public API)
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
// Content-only parsing for fallback scenarios (static internal function)
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
/*
|
||||
Copyright 2024 Google LLC
|
||||
|
||||
Use of this source code is governed by an MIT-style
|
||||
license that can be found in the LICENSE file or at
|
||||
https://opensource.org/licenses/MIT.
|
||||
*/
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "minja.hpp"
|
||||
#include <json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
namespace minja {
|
||||
|
||||
class chat_template {
|
||||
public:
|
||||
|
||||
private:
|
||||
bool supports_tools_ = true;
|
||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
bool requires_object_arguments_ = false;
|
||||
bool supports_system_role_ = true;
|
||||
bool supports_parallel_tool_calls_ = false;
|
||||
std::string source_;
|
||||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||
|
||||
std::string try_render(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
try {
|
||||
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
|
||||
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
|
||||
return prompt;
|
||||
} catch (const std::exception & e) {
|
||||
// fprintf(stderr, "Error: %s\n", e.what());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
|
||||
{
|
||||
template_root_ = minja::Parser::parse(source_, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
supports_tools_ = source.find("tools") != std::string::npos;
|
||||
|
||||
auto renders_string_arguments =
|
||||
try_render({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "Hey"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
|
||||
{"name", "ipython"},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
}
|
||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
||||
if (!renders_string_arguments) {
|
||||
auto renders_object_arguments =
|
||||
try_render({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "Hey"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", {
|
||||
{"code", "print('Hello, World!')"},
|
||||
}},
|
||||
{"name", "ipython"},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
}
|
||||
}, {}, false).find("{\"code\": \"print") != std::string::npos;
|
||||
requires_object_arguments_ = renders_object_arguments;
|
||||
}
|
||||
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
|
||||
|
||||
supports_system_role_ = try_render({
|
||||
{{"role", "system"}, {"content", "<System Needle>"}},
|
||||
{{"role", "user"}, {"content", "Hey"}}
|
||||
}, {}, false).find("<System Needle>") != std::string::npos;
|
||||
}
|
||||
|
||||
const std::string & source() const { return source_; }
|
||||
const std::string & bos_token() const { return bos_token_; }
|
||||
const std::string & eos_token() const { return eos_token_; }
|
||||
bool supports_tools() const { return supports_tools_; }
|
||||
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
|
||||
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
json actual_messages;
|
||||
|
||||
// First, "fix" messages so they have a chance to be rendered correctly by the template
|
||||
|
||||
if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
|
||||
actual_messages = json::array();
|
||||
|
||||
std::string pending_system;
|
||||
auto flush_sys = [&]() {
|
||||
if (!pending_system.empty()) {
|
||||
actual_messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", pending_system},
|
||||
});
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
for (const auto & message_ : messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
if (message.contains("tool_calls")) {
|
||||
if (requires_object_arguments_ || !supports_tools_) {
|
||||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call["type"] == "function") {
|
||||
auto & function = tool_call.at("function");
|
||||
std::string arguments = function.at("arguments");
|
||||
function["arguments"] = json::parse(arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!supports_tools_) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_call.at("function");
|
||||
auto tc = json {
|
||||
{"name", function.at("name")},
|
||||
{"arguments", function.at("arguments")},
|
||||
};
|
||||
if (tool_call.contains("id")) {
|
||||
tc["id"] = tool_call["id"];
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!content.is_null() && content != "") {
|
||||
obj["content"] = content;
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
}
|
||||
}
|
||||
if (!supports_tools_ && role == "tool") {
|
||||
message["role"] = "user";
|
||||
auto obj = json {
|
||||
{"tool_response", {
|
||||
{"tool", message.at("name")},
|
||||
{"content", message.at("content")},
|
||||
}},
|
||||
};
|
||||
if (message.contains("tool_call_id")) {
|
||||
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("name");
|
||||
}
|
||||
|
||||
if (!message["content"].is_null() && !supports_system_role_) {
|
||||
std::string content = message.at("content");
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
pending_system += content;
|
||||
continue;
|
||||
} else {
|
||||
if (role == "user") {
|
||||
if (!pending_system.empty()) {
|
||||
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
|
||||
pending_system.clear();
|
||||
}
|
||||
} else {
|
||||
flush_sys();
|
||||
}
|
||||
}
|
||||
}
|
||||
actual_messages.push_back(message);
|
||||
}
|
||||
flush_sys();
|
||||
} else {
|
||||
actual_messages = messages;
|
||||
}
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
{"add_generation_prompt", add_generation_prompt},
|
||||
{"bos_token", bos_token_},
|
||||
{"eos_token", eos_token_},
|
||||
}));
|
||||
|
||||
if (!tools.is_null()) {
|
||||
auto tools_val = minja::Value(tools);
|
||||
context->set("tools", tools_val);
|
||||
}
|
||||
if (!extra_context.is_null()) {
|
||||
for (auto & kv : extra_context.items()) {
|
||||
minja::Value val(kv.value());
|
||||
context->set(kv.key(), val);
|
||||
}
|
||||
}
|
||||
|
||||
return template_root_->render(context);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace minja
|
||||
2451
common/chat.cpp
2451
common/chat.cpp
File diff suppressed because it is too large
Load Diff
171
common/chat.h
171
common/chat.h
@@ -1,37 +1,16 @@
|
||||
// Chat support with builder pattern for llama.cpp compatibility
|
||||
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
// Forward declarations
|
||||
struct common_chat_templates;
|
||||
|
||||
// Basic data structures compatible with original llama.cpp
|
||||
// A pair of offsets [begin, end] into some string; construction guarantees
// that `begin` never exceeds `end`.
struct common_string_range {
    size_t begin;
    size_t end;

    // Throws std::runtime_error("Invalid range") for an inverted pair.
    common_string_range(size_t first, size_t last) : begin(first), end(last) {
        if (last < first) {
            throw std::runtime_error("Invalid range");
        }
    }

    // A range always has to be given explicit bounds.
    common_string_range() = delete;

    // True when both offsets coincide.
    bool empty() const { return end == begin; }

    // Two ranges are equal when both offsets match.
    bool operator==(const common_string_range & rhs) const {
        if (begin != rhs.begin) {
            return false;
        }
        return end == rhs.end;
    }
};
|
||||
|
||||
struct common_chat_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
@@ -40,10 +19,6 @@ struct common_chat_tool_call {
|
||||
bool operator==(const common_chat_tool_call & other) const {
|
||||
return name == other.name && arguments == other.arguments && id == other.id;
|
||||
}
|
||||
|
||||
bool operator!=(const common_chat_tool_call & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_content_part {
|
||||
@@ -64,11 +39,11 @@ struct common_chat_msg {
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() &&
|
||||
reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||
if (ids_cache.size() <= i) {
|
||||
@@ -81,7 +56,6 @@ struct common_chat_msg {
|
||||
tool_calls[i].id = ids_cache[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const common_chat_msg & other) const {
|
||||
return role == other.role
|
||||
&& content == other.content
|
||||
@@ -91,7 +65,6 @@ struct common_chat_msg {
|
||||
&& tool_name == other.tool_name
|
||||
&& tool_call_id == other.tool_call_id;
|
||||
}
|
||||
|
||||
bool operator!=(const common_chat_msg & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
@@ -110,10 +83,6 @@ struct common_chat_msg_diff {
|
||||
&& tool_call_index == other.tool_call_index
|
||||
&& tool_call_delta == other.tool_call_delta;
|
||||
}
|
||||
|
||||
bool operator!=(const common_chat_msg_diff & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -131,50 +100,110 @@ enum common_chat_tool_choice {
|
||||
enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
COMMON_CHAT_FORMAT_GENERIC,
|
||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_GRANITE,
|
||||
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility)
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_AUTO,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
|
||||
// Input bundle taken by common_chat_templates_apply(): the conversation to render
// plus the options that steer rendering. Every field has a default, so callers only
// set what they need.
struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_msg> messages;
|
||||
std::string grammar;
|
||||
std::string json_schema;
|
||||
bool add_generation_prompt = true;
|
||||
bool use_jinja = true;
|
||||
// Parameters below only supported when use_jinja is true
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = true;
|
||||
// Defaults to the wall-clock time at which this struct is initialized.
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
// NOTE(review): presumably extra template variables keyed by name, with values
// already serialized to text - confirm against the code that fills this map.
std::map<std::string, std::string> chat_template_kwargs;
|
||||
bool add_bos = false;
|
||||
bool add_eos = false;
|
||||
};
|
||||
|
||||
// Value returned by common_chat_templates_apply().
// NOTE(review): the per-field notes below are inferred from names and defaults only;
// confirm against the implementation in chat.cpp.
struct common_chat_params {
|
||||
// Defaults to plain content with no tool-call syntax.
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE;
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool enable_thinking = false;
|
||||
bool enable_tool_calls = true;
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
};
|
||||
|
||||
// Exception for partial parsing
|
||||
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
public:
|
||||
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||
|
||||
// Bridge functions to integrate with existing ik_llama.cpp system
|
||||
// TODO: Uncomment and implement during integration phase
|
||||
// common_chat_msg ik_to_common_msg(const struct ik_chat_msg & ik_msg);
|
||||
// struct ik_chat_msg common_to_ik_msg(const common_chat_msg & common_msg);
|
||||
void common_chat_templates_free(struct common_chat_templates * tmpls);
|
||||
|
||||
// Format detection from chat template
|
||||
common_chat_format common_chat_format_detect(const std::string & chat_template);
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
|
||||
|
||||
// Main parsing function (entry point for original llama.cpp compatibility)
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
|
||||
|
||||
// Forward declare parser class
|
||||
class common_chat_msg_parser;
|
||||
common_chat_templates_ptr common_chat_templates_init(
|
||||
const struct llama_model * model,
|
||||
const std::string & chat_template_override,
|
||||
const std::string & bos_token_override = "",
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
// Format-specific parsing functions (accessible from chat-parser)
|
||||
void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder);
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
||||
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string common_chat_format_single(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string common_chat_format_example(
|
||||
const struct common_chat_templates * tmpls,
|
||||
bool use_jinja);
|
||||
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string& format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include "chat.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
@@ -230,13 +230,13 @@ void gpt_params_handle_model_default(gpt_params & params) {
|
||||
}
|
||||
params.hf_file = params.model;
|
||||
} else if (params.model.empty()) {
|
||||
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
|
||||
params.model = fs_get_cache_file(string_split(params.hf_file, "/").back());
|
||||
}
|
||||
} else if (!params.model_url.empty()) {
|
||||
if (params.model.empty()) {
|
||||
auto f = string_split(params.model_url, '#').front();
|
||||
f = string_split(f, '?').front();
|
||||
params.model = fs_get_cache_file(string_split(f, '/').back());
|
||||
auto f = string_split(params.model_url, "#").front();
|
||||
f = string_split(f, "?").front();
|
||||
params.model = fs_get_cache_file(string_split(f, "/").back());
|
||||
}
|
||||
} else if (params.model.empty()) {
|
||||
params.model = DEFAULT_MODEL_PATH;
|
||||
@@ -295,7 +295,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
if (!params.chat_template.empty() && !llama_chat_verify_template(nullptr, params.chat_template, params.use_jinja)) {
|
||||
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"error: the supplied chat template is not supported: %s%s\n",
|
||||
params.chat_template.c_str(),
|
||||
@@ -599,7 +599,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
if (arg == "--samplers") {
|
||||
CHECK_ARG
|
||||
const auto sampler_names = string_split(argv[i], ';');
|
||||
const auto sampler_names = string_split(argv[i], ";");
|
||||
sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
|
||||
return true;
|
||||
}
|
||||
@@ -1486,6 +1486,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg == "--reasoning-budget") {
|
||||
CHECK_ARG
|
||||
params.reasoning_budget = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--sql-save-file") {
|
||||
CHECK_ARG
|
||||
params.sql_save_file = argv[i];
|
||||
@@ -1498,7 +1503,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
}
|
||||
if (arg == "--chat-template") {
|
||||
CHECK_ARG
|
||||
if (!llama_chat_verify_template(nullptr, argv[i], false)) {
|
||||
if (!common_chat_verify_template(argv[i], true)) {
|
||||
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
|
||||
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
|
||||
invalid_param = true;
|
||||
@@ -1510,9 +1515,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
if (arg == "--chat-template-file") {
|
||||
CHECK_ARG
|
||||
std::string chat_template = read_file(std::string(argv[i]));
|
||||
if (!llama_chat_verify_template(nullptr, chat_template, false)) {
|
||||
if (!common_chat_verify_template(chat_template, true)) {
|
||||
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
|
||||
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
|
||||
invalid_param = true;
|
||||
return true;
|
||||
}
|
||||
@@ -1523,6 +1527,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.use_jinja = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--chat-template-kwargs") {
|
||||
CHECK_ARG
|
||||
std::string value = argv[i];
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto& item : parsed.items()) {
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg == "--reasoning-format") {
|
||||
CHECK_ARG
|
||||
std::string value = argv[i];
|
||||
params.reasoning_format = common_reasoning_format_from_name(value);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--no-prefill-assistant") {
    // Boolean switch: it carries no value, so it must not go through CHECK_ARG.
    // The value-taking options read argv[i] right after that macro, i.e. it
    // moves on to the next command-line token; using it here made this flag
    // swallow whatever argument followed it.
    params.prefill_assistant = false;
    return true;
}
|
||||
if (arg == "--slot-prompt-similarity" || arg == "-sps") {
|
||||
CHECK_ARG
|
||||
params.slot_prompt_similarity = std::stof(argv[i]);
|
||||
@@ -1831,11 +1855,22 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
|
||||
"negative prompt file to use for guidance" });
|
||||
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
|
||||
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
|
||||
options.push_back({ "main", " --jinja",
|
||||
"set custom jinja chat template (default: template taken from model's metadata)\n"
|
||||
"if suffix/prefix are specified, template will be disabled\n"
|
||||
"only commonly used templates are accepted:\n"
|
||||
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
|
||||
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
|
||||
"use jinja template for chat (default: disabled)\n" });
|
||||
options.push_back({ "main", " --reasoning-format FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
"- none: leaves thoughts unparsed in `message.content`\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
|
||||
"(default: none)", });
|
||||
options.push_back({ "main", " --chat-template-kwargs JSON", "sets additional params for the json template parser"});
|
||||
options.push_back({ "main", " --reasoning-budget N", "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)" });
|
||||
options.push_back({ "main", " --no-prefill-assistant", "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
|
||||
"when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" });
|
||||
options.push_back({ "grammar" });
|
||||
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
|
||||
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
|
||||
@@ -2095,42 +2130,66 @@ std::string string_format(const char* fmt, ...) {
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
// Returns a copy of `s` in which every regex metacharacter
// (. ^ $ | ( ) * + ? [ ] { } \) is preceded by a backslash, so the result
// matches `s` literally when used as a pattern.
std::string regex_escape(const std::string& s) {
    static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
    // "$&" is the ECMAScript replacement pattern for "the whole match".
    // "$0" is not part of that grammar and only works as a library extension.
    return std::regex_replace(s, special_chars, "\\$&");
}
|
||||
|
||||
std::vector<std::string> string_split(std::string input, char separator) {
|
||||
std::vector<std::string> parts;
|
||||
size_t separator_pos = input.find(separator);
|
||||
while (separator_pos != std::string::npos) {
|
||||
std::string part = input.substr(0, separator_pos);
|
||||
parts.emplace_back(part);
|
||||
input = input.substr(separator_pos + 1);
|
||||
separator_pos = input.find(separator);
|
||||
std::string string_join(const std::vector<std::string>& values, const std::string& separator) {
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
if (i > 0) {
|
||||
result << separator;
|
||||
}
|
||||
result << values[i];
|
||||
}
|
||||
parts.emplace_back(input);
|
||||
return result.str();
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
|
||||
std::vector<std::string> parts;
|
||||
size_t start = 0;
|
||||
size_t end = str.find(delimiter);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
parts.push_back(str.substr(start, end - start));
|
||||
start = end + delimiter.length();
|
||||
end = str.find(delimiter, start);
|
||||
}
|
||||
|
||||
parts.push_back(str.substr(start));
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
std::string string_join(const std::vector<std::string> & strs, const std::string & delimiter) {
|
||||
if (strs.empty()) {
|
||||
return "";
|
||||
std::vector<std::string> string_split(const std::string& str, char delim) {
|
||||
std::vector<std::string> values;
|
||||
std::istringstream str_stream(str);
|
||||
std::string token;
|
||||
while (std::getline(str_stream, token, delim)) {
|
||||
std::string value;
|
||||
std::istringstream token_stream(token);
|
||||
token_stream >> value;
|
||||
values.push_back(value);
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < strs.size(); ++i) {
|
||||
if (i > 0) {
|
||||
oss << delimiter;
|
||||
}
|
||||
oss << strs[i];
|
||||
}
|
||||
return oss.str();
|
||||
// Whitespace test that is safe to call on raw UTF-8 bytes: only 7-bit values are
// handed to isspace(); every byte with the high bit set is reported as
// non-whitespace (recognising those would need a full Unicode table).
static bool is_utf8_whitespace(uint8_t c) {
    const bool is_ascii = c < 0x80;
    return is_ascii ? isspace(c) != 0 : false;
}
|
||||
|
||||
std::string string_strip(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
while (start < end && std::isspace(str[start])) {
|
||||
while (start < end && is_utf8_whitespace(str[start])) {
|
||||
start++;
|
||||
}
|
||||
while (end > start && std::isspace(str[end - 1])) {
|
||||
while (end > start && is_utf8_whitespace(str[end - 1])) {
|
||||
end--;
|
||||
}
|
||||
return str.substr(start, end - start);
|
||||
@@ -2163,6 +2222,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true when `str` ends with `suffix`. An empty suffix always matches.
bool string_ends_with(const std::string_view& str, const std::string_view& suffix) {
    const size_t suffix_len = suffix.size();
    if (suffix_len > str.size()) {
        return false;
    }
    return str.substr(str.size() - suffix_len) == suffix;
}
|
||||
size_t string_find_partial_stop(const std::string_view& str, const std::string_view& stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
void string_process_escapes(std::string & input) {
|
||||
std::size_t input_len = input.length();
|
||||
std::size_t output_idx = 0;
|
||||
@@ -3140,154 +3218,172 @@ bool llama_should_add_bos_token(const llama_model * model) {
|
||||
//
|
||||
// Chat template utils
|
||||
//
|
||||
//
|
||||
//bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
|
||||
// if (use_jinja) {
|
||||
// try {
|
||||
// auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
|
||||
// common_chat_inputs inputs;
|
||||
// inputs.messages = json::array({ {
|
||||
// {"role", "user"},
|
||||
// {"content", "test"},
|
||||
// } });
|
||||
// common_chat_params_init(chat_template, inputs);
|
||||
// return true;
|
||||
// }
|
||||
// catch (const std::exception& e) {
|
||||
// fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
|
||||
// return false;
|
||||
// }
|
||||
// }
|
||||
// llama_chat_message chat[] = { {"user", "test"} };
|
||||
// const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
// return res >= 0;
|
||||
//}
|
||||
|
||||
bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({ {
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
} }, json(), true);
|
||||
return true;
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
//std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
// const common_chat_template& tmpl,
|
||||
// const std::vector<common_chat_msg> & msgs,
|
||||
// bool add_ass,
|
||||
// bool use_jinja) {
|
||||
// if (use_jinja) {
|
||||
// auto messages = json::array();
|
||||
// for (const auto& msg : msgs) {
|
||||
// messages.push_back({ {"role", msg.role}, {"content", msg.content} });
|
||||
// }
|
||||
// common_chat_inputs inputs;
|
||||
// inputs.messages = messages;
|
||||
// inputs.add_generation_prompt = add_ass;
|
||||
// return common_chat_params_init(tmpl, inputs).prompt;
|
||||
// }
|
||||
// int alloc_size = 0;
|
||||
// std::vector<llama_chat_message> chat;
|
||||
// for (auto & msg : msgs) {
|
||||
// chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
// alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
// }
|
||||
//
|
||||
// std::vector<char> buf(alloc_size);
|
||||
//
|
||||
// // run the first time to get the total output length
|
||||
// int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// // error: chat template is not supported
|
||||
// if (res < 0) {
|
||||
// // if the custom "tmpl" is not supported, we throw an error
|
||||
// // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
// throw std::runtime_error("this custom template is not supported");
|
||||
// }
|
||||
//
|
||||
// // if it turns out that our buffer is too small, we resize it
|
||||
// if ((size_t)res > buf.size()) {
|
||||
// buf.resize(res);
|
||||
// res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// }
|
||||
//
|
||||
// std::string formatted_chat(buf.data(), res);
|
||||
// return formatted_chat;
|
||||
//}
|
||||
////
|
||||
//std::string llama_chat_format_single(const struct llama_model * model,
|
||||
// const common_chat_template& tmpl,
|
||||
// const std::vector<common_chat_msg> & past_msg,
|
||||
// const common_chat_msg & new_msg,
|
||||
// bool add_ass,
|
||||
// bool use_jinja) {
|
||||
// std::ostringstream ss;
|
||||
// auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
|
||||
// std::vector<common_chat_msg> chat_new(past_msg);
|
||||
// // if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
// if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
// ss << "\n";
|
||||
// };
|
||||
// // format chat with new_msg
|
||||
// chat_new.push_back(new_msg);
|
||||
// auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
|
||||
// // get the diff part
|
||||
// ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
// return ss.str();
|
||||
//}
|
||||
|
||||
std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector<llama_chat_msg> & msgs,
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
auto messages = json::array();
|
||||
for (const auto& msg : msgs) {
|
||||
messages.push_back({ {"role", msg.role}, {"content", msg.content} });
|
||||
}
|
||||
return tmpl.apply(messages, /* tools= */ json(), add_ass);
|
||||
}
|
||||
int alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
return formatted_chat;
|
||||
}
|
||||
|
||||
std::string llama_chat_format_single(const struct llama_model * model,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector<llama_chat_msg> & past_msg,
|
||||
const llama_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
|
||||
std::vector<llama_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
ss << "\n";
|
||||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
|
||||
std::vector<llama_chat_msg> msgs = {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
{"assistant", "Hi there"},
|
||||
{"user", "How are you?"},
|
||||
};
|
||||
return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
|
||||
}
|
||||
|
||||
|
||||
// Builds the set of chat templates to use for `model`.
//
// Template sources are chosen as follows:
//   - a non-empty `chat_template_override` is used for both the default and the
//     tool-use template;
//   - otherwise the model is asked for its default template and for its
//     "tool_use" template;
//   - if the default source is still empty or is the literal "chatml", the
//     tool-use source is used when present, else a built-in ChatML template.
//
// Returns { has_explicit_template, default template, tool-use template }, where
// the tool-use template is nullptr when no tool-use source is available.
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
|
||||
{
|
||||
// NOTE(review): `vocab` is not referenced again in this function — confirm whether it is still needed.
auto vocab = llama_model_get_vocab(model);
|
||||
std::string default_template_src = chat_template_override;
|
||||
std::string template_tool_use_src = chat_template_override;
|
||||
// True when an override was given; also set below when the model supplies a template.
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
// No override: query the model's default template (name == nullptr).
auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
if (str) {
|
||||
default_template_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
// ... and its dedicated "tool_use" template, if it has one.
str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
if (str) {
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
}
|
||||
// Fallback for a missing default template or the "chatml" shorthand.
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
}
|
||||
else {
|
||||
// Built-in ChatML template (raw string literal; its contents are runtime data).
default_template_src = R"(
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
)";
|
||||
}
|
||||
}
|
||||
// Converts a special token to its text piece. For LLAMA_TOKEN_NULL it returns an
// empty string, and warns on stdout if either template source mentions
// `jinja_variable_name`.
// NOTE(review): __func__ is evaluated inside this lambda, so the warning prefix
// is the lambda's name rather than the enclosing function's.
const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
}
|
||||
return std::string();
|
||||
}
|
||||
else {
|
||||
return llama_token_to_piece(model, token, true);
|
||||
}
|
||||
};
|
||||
auto token_bos = get_token(llama_token_bos(model), "BOS", "bos_token");
|
||||
auto token_eos = get_token(llama_token_eos(model), "EOS", "eos_token");
|
||||
// Both templates are constructed with the same BOS/EOS strings.
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
|
||||
};
|
||||
}
|
||||
//std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
|
||||
// std::vector<common_chat_msg> msgs = {
|
||||
// {"system", "You are a helpful assistant", {}},
|
||||
// {"user", "Hello", {}},
|
||||
// {"assistant", "Hi there", {}},
|
||||
// {"user", "How are you?", {}},
|
||||
// };
|
||||
// return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
|
||||
//}
|
||||
//
|
||||
//#define CHATML_TEMPLATE_SRC \
|
||||
// "{%- for message in messages -%}\n" \
|
||||
// " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
||||
// "{%- endfor -%}\n" \
|
||||
// "{%- if add_generation_prompt -%}\n" \
|
||||
// " {{- '<|im_start|>assistant\n' -}}\n" \
|
||||
// "{%- endif -%}"
|
||||
//
|
||||
//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
|
||||
//{
|
||||
// std::string default_template_src;
|
||||
// std::string template_tool_use_src;
|
||||
// bool has_explicit_template = !chat_template_override.empty();
|
||||
// if (chat_template_override.empty()) {
|
||||
// auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
// if (str) {
|
||||
// default_template_src = str;
|
||||
// has_explicit_template = true;
|
||||
// }
|
||||
// str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
// if (str) {
|
||||
// template_tool_use_src = str;
|
||||
// has_explicit_template = true;
|
||||
// }
|
||||
// }
|
||||
// else {
|
||||
// default_template_src = chat_template_override;
|
||||
// }
|
||||
// if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
// if (!template_tool_use_src.empty()) {
|
||||
// default_template_src = template_tool_use_src;
|
||||
// }
|
||||
// else {
|
||||
// default_template_src = CHATML_TEMPLATE_SRC;
|
||||
// }
|
||||
// }
|
||||
// auto vocab = llama_model_get_vocab(model);
|
||||
// const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
|
||||
// if (token == LLAMA_TOKEN_NULL) {
|
||||
// if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
// || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
// fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
// }
|
||||
// return std::string();
|
||||
// }
|
||||
// else {
|
||||
// return llama_token_to_piece(model, token, true);
|
||||
// }
|
||||
// };
|
||||
// auto token_bos = get_token(llama_token_bos_impl(*vocab), "BOS", "bos_token");
|
||||
// auto token_eos = get_token(llama_token_eos_impl(*vocab), "EOS", "eos_token");
|
||||
// try {
|
||||
// return {
|
||||
// has_explicit_template,
|
||||
// std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
// template_tool_use_src.empty()
|
||||
// ? nullptr
|
||||
// : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
|
||||
// };
|
||||
// }
|
||||
// catch (const std::exception& e) {
|
||||
// LOG("%s: failed to parse chat template: %s\n", __func__, e.what());
|
||||
// return {
|
||||
// has_explicit_template,
|
||||
// std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
|
||||
// nullptr,
|
||||
// };
|
||||
// }
|
||||
//}
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
@@ -3778,27 +3874,3 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
||||
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
|
||||
}
|
||||
|
||||
// Additional string utilities for builder pattern compatibility
|
||||
// Returns true when `str` begins with `prefix`. An empty prefix always matches.
bool string_starts_with(const std::string & str, const std::string & prefix) {
    // compare() clamps the window to the length of `str`, so a prefix longer
    // than `str` simply compares unequal.
    return str.compare(0, prefix.size(), prefix) == 0;
}
|
||||
|
||||
// Returns true when the last suffix.size() characters of `str` equal `suffix`.
// An empty suffix always matches.
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
    if (str.size() < suffix.size()) {
        return false;
    }
    const size_t offset = str.size() - suffix.size();
    return str.compare(offset, std::string_view::npos, suffix) == 0;
}
|
||||
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
160
common/common.h
160
common/common.h
@@ -15,14 +15,18 @@
|
||||
|
||||
#define LOG_NO_FILE_LINE_FUNCTION
|
||||
#include "log.h"
|
||||
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <tuple>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
@@ -74,6 +78,14 @@ enum dimre_method {
|
||||
DIMRE_METHOD_MEAN,
|
||||
};
|
||||
|
||||
// Reasoning format of the API response (not to be confused with the chat template's reasoning format).
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE, // default value of gpt_params::reasoning_format
|
||||
COMMON_REASONING_FORMAT_AUTO,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
|
||||
};
|
||||
|
||||
struct gpt_params {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
|
||||
|
||||
@@ -240,13 +252,21 @@ struct gpt_params {
|
||||
bool use_jinja = false; // NOLINT
|
||||
std::string system_prompt = "";
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true;
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
std::string ssl_file_key = "";
|
||||
std::string ssl_file_cert = "";
|
||||
|
||||
bool endpoint_slots = true;
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// "advanced" endpoints are disabled by default for better security
|
||||
bool webui = true;
|
||||
bool endpoint_slots = false;
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
bool log_json = false;
|
||||
@@ -314,19 +334,24 @@ std::string gpt_params_get_system_info(const gpt_params & params);
|
||||
//
|
||||
// String utils
|
||||
//
|
||||
|
||||
std::vector<std::string> string_split(std::string input, char separator);
|
||||
std::string string_join(const std::vector<std::string> & strs, const std::string & delimiter);
|
||||
|
||||
std::string string_join(const std::vector<std::string>& values, const std::string& separator);
|
||||
std::string string_strip(const std::string & str);
|
||||
std::string string_get_sortable_timestamp();
|
||||
|
||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
// Returns true when `str` begins with `prefix` (an empty prefix always matches).
// Stand-in until C++20's std::string::starts_with can be used.
static bool string_starts_with(const std::string& str,
                               const std::string& prefix) {
    if (prefix.size() > str.size()) {
        return false;
    }
    return str.compare(0, prefix.size(), prefix) == 0;
}
|
||||
|
||||
// Additional string utilities for builder pattern compatibility
|
||||
bool string_starts_with(const std::string & str, const std::string & prefix);
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
std::vector<std::string> string_split(const std::string& str, const std::string& delimiter);
|
||||
std::vector<std::string> string_split(const std::string& str, char delim);
|
||||
|
||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view& str, const std::string_view& suffix);
|
||||
size_t string_find_partial_stop(const std::string_view& str, const std::string_view& stop);
|
||||
|
||||
std::string regex_escape(const std::string& s);
|
||||
|
||||
template<class T>
|
||||
static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
@@ -342,6 +367,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
return values;
|
||||
}
|
||||
|
||||
template<>
|
||||
std::vector<std::string> string_split<std::string>(const std::string& input, char separator)
|
||||
{
|
||||
std::vector<std::string> parts;
|
||||
size_t begin_pos = 0;
|
||||
size_t separator_pos = input.find(separator);
|
||||
while (separator_pos != std::string::npos) {
|
||||
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
||||
parts.emplace_back(part);
|
||||
begin_pos = separator_pos + 1;
|
||||
separator_pos = input.find(separator, begin_pos);
|
||||
}
|
||||
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
||||
return parts;
|
||||
}
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
|
||||
@@ -432,52 +473,59 @@ bool llama_should_add_bos_token(const llama_model * model);
|
||||
//
|
||||
// Chat template utils
|
||||
//
|
||||
//struct common_tool_call {
|
||||
// std::string name;
|
||||
// std::string arguments;
|
||||
// std::string id;
|
||||
//};
|
||||
//
|
||||
//// same with llama_chat_message, but uses std::string
|
||||
//struct common_chat_msg {
|
||||
// std::string role;
|
||||
// std::string content;
|
||||
// std::vector<common_tool_call> tool_calls;
|
||||
// std::string reasoning_content = "";
|
||||
//};
|
||||
|
||||
// Same as llama_chat_message, but uses std::string for its fields.
|
||||
struct llama_chat_msg {
|
||||
std::string role;    // speaker of the message, e.g. "system", "user" or "assistant"
|
||||
std::string content; // text of the message
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool llama_chat_verify_template(const struct llama_model* , const std::string& tmpl, bool use_jinja);
|
||||
|
||||
namespace minja {
|
||||
class chat_template;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
// Chat templates resolved for a model, as returned by llama_chat_templates_from_model().
struct common_chat_templates {
|
||||
bool has_explicit_template; // Model had a built-in template or a template override was specified.
|
||||
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
||||
std::unique_ptr<common_chat_template> template_tool_use; // nullptr when no tool-use template source is available
|
||||
};
|
||||
|
||||
|
||||
// CPP wrapper for llama_chat_apply_template
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
// If the custom "tmpl" is not supported, we throw an error
|
||||
std::string llama_chat_apply_template(
|
||||
const struct llama_model* model,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector< llama_chat_msg>& chat,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string llama_chat_format_single(const struct llama_model* model,
|
||||
const common_chat_template& tmpl,
|
||||
const std::vector< llama_chat_msg>& past_msg,
|
||||
const llama_chat_msg& new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string llama_chat_format_example(const struct llama_model* model,
|
||||
const common_chat_template& tmpl, bool use_jinja);
|
||||
|
||||
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override);
|
||||
//// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
//bool llama_chat_verify_template(const struct llama_model* , const std::string& tmpl, bool use_jinja);
|
||||
//
|
||||
//namespace minja {
|
||||
// class chat_template;
|
||||
//}
|
||||
//
|
||||
//typedef minja::chat_template common_chat_template;
|
||||
//
|
||||
//struct common_chat_templates {
|
||||
// bool has_explicit_template; // Model had builtin template or template overridde was specified.
|
||||
// std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
||||
// std::unique_ptr<common_chat_template> template_tool_use;
|
||||
//};
|
||||
//
|
||||
//
|
||||
//// CPP wrapper for llama_chat_apply_template
|
||||
//// If the built-in template is not supported, we default to chatml
|
||||
//// If the custom "tmpl" is not supported, we throw an error
|
||||
//std::string llama_chat_apply_template(
|
||||
// const struct llama_model* model,
|
||||
// const common_chat_template& tmpl,
|
||||
// const std::vector< common_chat_msg>& chat,
|
||||
// bool add_ass,
|
||||
// bool use_jinja);
|
||||
//
|
||||
//// Format single message, while taking into account the position of that message in chat history
|
||||
//std::string llama_chat_format_single(const struct llama_model* model,
|
||||
// const common_chat_template& tmpl,
|
||||
// const std::vector< common_chat_msg>& past_msg,
|
||||
// const common_chat_msg& new_msg,
|
||||
// bool add_ass,
|
||||
// bool use_jinja);
|
||||
//
|
||||
//// Returns an example of formatted chat
|
||||
//std::string llama_chat_format_example(const struct llama_model* model,
|
||||
// const common_chat_template& tmpl, bool use_jinja);
|
||||
//
|
||||
//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override);
|
||||
|
||||
|
||||
//
|
||||
|
||||
@@ -383,10 +383,13 @@ namespace grammar_parser {
|
||||
}
|
||||
}
|
||||
}
|
||||
state.success = true;
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
|
||||
parse_state state;
|
||||
state.success = false;
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ namespace grammar_parser {
|
||||
std::vector<std::vector<llama_grammar_element>> rules;
|
||||
|
||||
std::vector<const llama_grammar_element *> c_rules();
|
||||
bool success;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
#include "json-partial.h"
|
||||
|
||||
#include <json-partial.h>
|
||||
#include "ggml.h"
|
||||
#include "log.h"
|
||||
#include "../ggml/include/ggml.h"
|
||||
#include "../examples/server/utils.hpp"
|
||||
|
||||
#include "json.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum common_json_stack_element_type {
|
||||
@@ -129,7 +126,7 @@ bool common_json_parse(
|
||||
return true;
|
||||
} catch (const std::exception & ex) {
|
||||
// No, needs healing.
|
||||
LOG_VERBOSE("Failed to parse up to error", {{"error", ex.what()}, {"content", std::string(it, temptative_end)}});
|
||||
LOG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
||||
}
|
||||
auto can_parse = [](const std::string & str) {
|
||||
try {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "json.hpp"
|
||||
#include <json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
struct common_healing_marker {
|
||||
|
||||
@@ -267,7 +267,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
throw std::runtime_error("At least one of min_value or max_value must be set");
|
||||
}
|
||||
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
|
||||
|
||||
struct BuiltinRule {
|
||||
std::string content;
|
||||
@@ -389,6 +389,7 @@ static std::string format_literal(const std::string & literal) {
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder&)>& cb, const common_grammar_options& options);
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
std::map<std::string, std::string> _rules;
|
||||
@@ -1035,11 +1036,35 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
|
||||
auto copy = schema;
|
||||
converter.resolve_refs(copy, "input");
|
||||
converter.visit(copy, "");
|
||||
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
if (!force_gbnf) {
|
||||
return "%llguidance {}\nstart: %json " + schema.dump();
|
||||
}
|
||||
#else
|
||||
(void)force_gbnf;
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
return build_grammar([&](const common_grammar_builder& callbacks) {
|
||||
auto copy = schema;
|
||||
callbacks.resolve_refs(copy);
|
||||
callbacks.add_schema("", copy);
|
||||
});
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder&)>& cb, const common_grammar_options& options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_grammar_builder builder{
|
||||
/* .add_rule = */ [&](const std::string& name, const std::string& rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
},
|
||||
/* .add_schema = */ [&](const std::string& name, const nlohmann::ordered_json& schema) {
|
||||
return converter.visit(schema, name == "root" ? "" : name);
|
||||
},
|
||||
/* .resolve_refs = */ [&](nlohmann::ordered_json& schema) {
|
||||
converter.resolve_refs(schema, "");
|
||||
}
|
||||
};
|
||||
cb(builder);
|
||||
converter.check_errors();
|
||||
return converter.format_grammar();
|
||||
}
|
||||
|
||||
@@ -5,4 +5,17 @@
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||
bool force_gbnf = false);
|
||||
|
||||
// Set of callbacks handed to the user callback of build_grammar().
struct common_grammar_builder {
|
||||
// (name, rule) -> string: registers a rule with the underlying converter and returns its result.
std::function<std::string(const std::string&, const std::string&)> add_rule;
|
||||
// (name, schema) -> string: converts a JSON schema; build_grammar() maps the name "root" to the empty rule name.
std::function<std::string(const std::string&, const nlohmann::ordered_json&)> add_schema;
|
||||
// Resolves references inside `schema` (the schema is taken by non-const reference).
std::function<void(nlohmann::ordered_json&)> resolve_refs;
|
||||
};
|
||||
|
||||
// Options accepted by build_grammar().
struct common_grammar_options {
|
||||
bool dotall = false; // forwarded to the schema converter's `dotall` constructor argument
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder&)>& cb, const common_grammar_options& options = {});
|
||||
|
||||
270
common/llguidance.cpp
Normal file
270
common/llguidance.cpp
Normal file
@@ -0,0 +1,270 @@
|
||||
#include "sampling.h"
|
||||
#include "log.h"
|
||||
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
|
||||
# include "llguidance.h"
|
||||
# include <cmath>
|
||||
|
||||
// State of the llguidance-backed sampler.
struct llama_sampler_llg {
|
||||
const llama_vocab * vocab;        // used to recognise end-of-generation tokens when the grammar has stopped
|
||||
std::string grammar_kind;  // kept so the constraint can be rebuilt on reset
|
||||
std::string grammar_data;  // kept so the constraint can be rebuilt on reset
|
||||
LlgTokenizer * tokenizer;
|
||||
LlgConstraint * grammar;       // nullptr when there is no usable constraint (e.g. after an error)
|
||||
LlgMaskResult llg_res;       // last mask computed by llg_compute_mask()
|
||||
bool has_llg_res;   // true while llg_res is valid; cleared when a token is accepted or on reset
|
||||
};
|
||||
|
||||
// Creates an llguidance constraint for the given grammar.
//
// The init parameters start from the library defaults for `tokenizer`; if the
// LLGUIDANCE_LOG_LEVEL environment variable is set and non-empty, its integer
// value becomes the stderr log level. Returns nullptr (after logging and
// freeing the constraint) when the library reports an error.
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                             const char * grammar_data) {
    LlgConstraintInit init_params;
    llg_constraint_init_set_defaults(&init_params, tokenizer);

    const char * env_level = getenv("LLGUIDANCE_LOG_LEVEL");
    const bool has_level = env_level && *env_level;
    if (has_level) {
        init_params.log_stderr_level = atoi(env_level);
    }

    auto constraint = llg_new_constraint_any(&init_params, grammar_kind, grammar_data);
    if (!llg_get_error(constraint)) {
        return constraint;
    }

    LOG_ERR("llg error: %s\n", llg_get_error(constraint));
    llg_free_constraint(constraint);
    return nullptr;
}
|
||||
|
||||
// Sampler `name` callback: always reports the fixed identifier "llguidance".
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "llguidance";
    return sampler_name;
}
|
||||
|
||||
// Sampler `accept` callback: commits the chosen token to the constraint and invalidates the
// cached mask so the next apply recomputes it. Does nothing when no grammar is active.
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * state = (llama_sampler_llg *) smpl->ctx;
    if (!state->grammar) {
        return;
    }

    LlgCommitResult commit_result;  // filled by llguidance; not inspected here
    llg_commit_token(state->grammar, token, &commit_result);
    state->has_llg_res = false;
}
|
||||
|
||||
// Sampler `apply` callback: masks the candidate logits according to the llguidance constraint.
//
// - No grammar: candidates are left untouched.
// - Mask computation fails: the error is logged, the constraint is freed and disabled, and the
//   candidates are left untouched.
// - Constraint reports stop: every candidate that is not an end-of-generation token gets -INFINITY.
// - Otherwise: every candidate whose bit is clear in the sample mask gets -INFINITY.
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    // The mask is cached until the next accept/reset clears has_llg_res.
    if (!ctx->has_llg_res) {
        if (llg_compute_mask(ctx->grammar, &ctx->llg_res) != 0) {
            LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
            llg_free_constraint(ctx->grammar);
            ctx->grammar = nullptr;
            return;
        }
        ctx->has_llg_res = true;
    }

    if (ctx->llg_res.is_stop) {
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                cur_p->data[i].logit = -INFINITY;
            }
        }
        return;
    }

    // One bit per token id, packed into 32-bit words.
    const uint32_t * mask = ctx->llg_res.sample_mask;
    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto token = cur_p->data[i].id;
        // Build the bit as unsigned: a signed `1 << 31` depends on an implementation-defined
        // conversion before it is combined with the unsigned mask word.
        const uint32_t bit = uint32_t{1} << (token % 32);
        if ((mask[token / 32] & bit) == 0) {
            cur_p->data[i].logit = -INFINITY;
        }
    }
}
|
||||
|
||||
// Sampler `reset` callback: replaces the current constraint with a freshly built one from the
// stored grammar kind/data and drops the cached mask. A sampler without a grammar is left as is.
static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * state = (llama_sampler_llg *) smpl->ctx;
    if (state->grammar) {
        // Build the replacement first, then release the old constraint.
        auto * fresh = llama_sampler_llg_new(state->tokenizer, state->grammar_kind.c_str(),
                                             state->grammar_data.c_str());
        llg_free_constraint(state->grammar);
        state->grammar     = fresh;
        state->has_llg_res = false;
    }
}
|
||||
|
||||
// Sampler `clone` callback: creates an empty llguidance sampler for the same vocab and, when the
// source has an active grammar, copies the grammar description and clones the constraint and
// tokenizer handles into it. The cached mask is not copied.
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * src = (const llama_sampler_llg *) smpl->ctx;

    // Null grammar arguments yield a sampler with no tokenizer and no constraint.
    auto * copy = llama_sampler_init_llg(src->vocab, nullptr, nullptr);

    if (src->grammar) {
        auto * dst = (llama_sampler_llg *) copy->ctx;

        dst->grammar_kind = src->grammar_kind;
        dst->grammar_data = src->grammar_data;
        dst->grammar      = llg_clone_constraint(src->grammar);
        dst->tokenizer    = llg_clone_tokenizer(src->tokenizer);
    }

    return copy;
}
|
||||
|
||||
// Sampler `free` callback: releases the llguidance constraint and tokenizer owned by the sampler
// state, then the state itself.
static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    // The two handles are released independently. The grammar can be null while the tokenizer is
    // still alive: llama_sampler_llg_new returns nullptr on a grammar error, and
    // llama_sampler_llg_apply drops the constraint after a mask error. Tying the tokenizer release
    // to a non-null grammar would leak the tokenizer in those cases.
    if (ctx->grammar) {
        llg_free_constraint(ctx->grammar);
    }
    if (ctx->tokenizer) {
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}
|
||||
|
||||
// Callback table that plugs the llguidance functions of this file into the sampler interface.
static llama_sampler_i llama_sampler_llg_i = {
    llama_sampler_llg_name,         // .name
    llama_sampler_llg_accept_impl,  // .accept
    llama_sampler_llg_apply,        // .apply
    llama_sampler_llg_reset,        // .reset
    llama_sampler_llg_clone,        // .clone
    llama_sampler_llg_free,         // .free
};
|
||||
|
||||
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
|
||||
uint32_t * output_tokens, size_t output_tokens_len) {
|
||||
const llama_vocab * vocab = (const llama_vocab *) user_data;
|
||||
int r = 0;
|
||||
try {
|
||||
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
|
||||
true);
|
||||
} catch (const std::exception & e) {
|
||||
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
|
||||
}
|
||||
if (r < 0) {
|
||||
return -r;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// Returns a new llguidance tokenizer handle for `vocab` (a clone of a cached instance).
//
// The tokenizer built for the most recently seen vocab is kept in function-local statics; a
// request for the same vocab only clones it. Otherwise every token is detokenized into one
// contiguous byte buffer (plus a per-token length table), the tokenizer is created from that, and
// it replaces the cached instance. Returns nullptr (after logging) if llguidance rejects the
// tokenizer description; aborts if detokenization fails or the byte buffer would overflow.
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * cached_vocab;
    static LlgTokenizer *      cached_tokenizer;

    if (cached_vocab == vocab) {
        return llg_clone_tokenizer(cached_tokenizer);
    }

    // Prefer the end-of-turn token; fall back to end-of-sequence when the vocab has none.
    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes      = new uint8_t[token_bytes_size];

    // Room reserved in the buffer for each single token.
    const size_t max_token = 1024;

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto        dst   = (char *) token_bytes + offset;

        // First attempt: render the token without special-token text.
        auto n_bytes = llama_detokenize(vocab, &token, 1, dst, max_token, false, false);
        if (n_bytes < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }

        if (n_bytes == 0) {
            // Nothing rendered: retry with special tokens enabled, leaving one byte for the marker.
            n_bytes = llama_detokenize(vocab, &token, 1, dst + 1, max_token - 1, false, true);
            if (n_bytes < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (n_bytes != 0) {
                *dst = '\xff';  // special token prefix marker
                n_bytes += 1;
            }
        }

        token_lens[i] = n_bytes;
        offset += n_bytes;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
    };

    char           error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    // The temporary tables are released right after construction.
    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    // Swap the cache over to the new vocab/tokenizer pair.
    if (cached_tokenizer) {
        llg_free_tokenizer(cached_tokenizer);
    }
    cached_vocab     = vocab;
    cached_tokenizer = tokenizer;

    return llg_clone_tokenizer(cached_tokenizer);
}
|
||||
|
||||
// Creates an llguidance-backed sampler for `vocab`.
//
// With a non-empty `grammar_kind`, a tokenizer and a constraint are created for the grammar; the
// constraint may still end up null if llguidance rejects it. With a null or empty `grammar_kind`
// the sampler is created without tokenizer and constraint, in which case apply/accept do nothing.
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * state = new llama_sampler_llg;

    const bool has_grammar = grammar_kind != nullptr && grammar_kind[0] != '\0';
    if (!has_grammar) {
        *state = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    } else {
        // The tokenizer must exist before the constraint, which is built on top of it.
        auto tokenizer  = llama_sampler_llg_new_tokenizer(vocab);
        auto constraint = llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data);
        *state = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ constraint,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    }

    return new llama_sampler{
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ state,
    };
}
|
||||
|
||||
#else
|
||||
|
||||
// Fallback compiled when LLAMA_USE_LLGUIDANCE is not defined: logs that llguidance support is
// not enabled and returns nullptr. All arguments are ignored.
// NOTE(review): this variant returns llama_grammar * while the enabled variant returns
// llama_sampler * — confirm which one matches the declaration in sampling.h.
llama_grammar * llama_sampler_init_llg(const llama_vocab * /*vocab*/, const char * /*grammar_kind*/,
                                       const char * /*grammar_data*/) {
    LOG("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}
|
||||
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
@@ -1,5 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
|
||||
549
common/minja/chat-template.hpp
Normal file
549
common/minja/chat-template.hpp
Normal file
@@ -0,0 +1,549 @@
|
||||
/*
|
||||
Copyright 2024 Google LLC
|
||||
|
||||
Use of this source code is governed by an MIT-style
|
||||
license that can be found in the LICENSE file or at
|
||||
https://opensource.org/licenses/MIT.
|
||||
*/
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "minja.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <exception>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <json.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
namespace minja {
|
||||
|
||||
// Capabilities of a chat template, as detected by probing it with dummy conversations.
// Every flag starts out false.
struct chat_template_caps {
    bool supports_tools{false};
    bool supports_tool_calls{false};
    bool supports_tool_responses{false};
    bool supports_system_role{false};
    bool supports_parallel_tool_calls{false};
    bool supports_tool_call_id{false};

    // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool requires_object_arguments{false};

    // CohereForAI/c4ai-command-r-plus simple variant
    bool requires_non_null_content{false};

    // MiniMaxAI/MiniMax-Text-01 special
    bool requires_typed_content{false};
};
|
||||
|
||||
struct chat_template_inputs {
|
||||
nlohmann::ordered_json messages;
|
||||
nlohmann::ordered_json tools;
|
||||
bool add_generation_prompt = true;
|
||||
nlohmann::ordered_json extra_context;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
// Switches controlling how a template is applied. Everything is enabled by default.
struct chat_template_options {
    // Master switch: when false, no polyfill is applied regardless of the flags below.
    bool apply_polyfills{true};

    // When false, the corresponding token is exposed to the template as an empty string.
    bool use_bos_token{true};
    bool use_eos_token{true};

    // Whether the strftime_now() callable is made available to the template.
    bool define_strftime_now{true};

    // Individual polyfills; each one only kicks in when the template lacks the capability.
    bool polyfill_tools{true};
    bool polyfill_tool_call_examples{true};
    bool polyfill_tool_calls{true};
    bool polyfill_tool_responses{true};
    bool polyfill_system_role{true};
    bool polyfill_object_arguments{true};
    bool polyfill_typed_content{true};
};
|
||||
|
||||
class chat_template {
|
||||
|
||||
private:
|
||||
chat_template_caps caps_;
|
||||
std::string source_;
|
||||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||
std::string tool_call_example_;
|
||||
|
||||
std::string try_raw_render(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
try {
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context = extra_context;
|
||||
// Use fixed date for tests
|
||||
inputs.now = std::chrono::system_clock::from_time_t(0);
|
||||
|
||||
chat_template_options opts;
|
||||
opts.apply_polyfills = false;
|
||||
|
||||
auto prompt = apply(inputs, opts);
|
||||
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
|
||||
return prompt;
|
||||
} catch (const std::exception & e) {
|
||||
// fprintf(stderr, "try_raw_render error: %s\n", e.what());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
|
||||
{
|
||||
template_root_ = minja::Parser::parse(source_, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
|
||||
auto contains = [](const std::string & haystack, const std::string & needle) {
|
||||
return haystack.find(needle) != std::string::npos;
|
||||
};
|
||||
|
||||
const std::string user_needle = "<User Needle>";
|
||||
const std::string sys_needle = "<System Needle>";
|
||||
const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
|
||||
const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
|
||||
|
||||
caps_.requires_typed_content =
|
||||
!contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
|
||||
&& contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
|
||||
|
||||
const auto dummy_user_msg = caps_.requires_typed_content
|
||||
? dummy_typed_user_msg
|
||||
: dummy_str_user_msg;
|
||||
const json needle_system_msg = {
|
||||
{"role", "system"},
|
||||
{"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
|
||||
};
|
||||
|
||||
caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
|
||||
|
||||
auto out = try_raw_render(json::array({
|
||||
dummy_user_msg
|
||||
}), json::array({
|
||||
{
|
||||
{"name", "some_tool"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "some_tool"},
|
||||
{"description", "Some tool."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"arg", {
|
||||
{"type", "string"},
|
||||
{"description", "Some argument."},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({ "arg" })},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}), false);
|
||||
caps_.supports_tools = contains(out, "some_tool");
|
||||
|
||||
const auto render_with_content = [&](const json & content) {
|
||||
const json assistant_msg {{"role", "assistant"}, {"content", content}};
|
||||
// Render two assistant messages as some templates like QwQ-32B are handling
|
||||
// the content differently depending on whether it's the last message or not
|
||||
// (to remove the <think> tag in all but the last message).
|
||||
return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false);
|
||||
};
|
||||
auto out_empty = render_with_content("");
|
||||
auto out_null = render_with_content(json());
|
||||
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
|
||||
|
||||
json j_null;
|
||||
auto make_tool_calls_msg = [&](const json & tool_calls) {
|
||||
return json {
|
||||
{"role", "assistant"},
|
||||
{"content", caps_.requires_non_null_content? "" : j_null},
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
};
|
||||
auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
|
||||
return json {
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", arguments},
|
||||
{"name", tool_name},
|
||||
}},
|
||||
};
|
||||
};
|
||||
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
|
||||
|
||||
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
|
||||
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
|
||||
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
|
||||
|
||||
if (caps_.supports_tool_calls) {
|
||||
auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
|
||||
auto tc1 = make_tool_call("test_tool1", dummy_args);
|
||||
auto tc2 = make_tool_call("test_tool2", dummy_args);
|
||||
auto out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({tc1, tc2})),
|
||||
}), {}, false);
|
||||
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
|
||||
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({tc1})),
|
||||
{
|
||||
{"role", "tool"},
|
||||
{"name", "test_tool1"},
|
||||
{"content", "Some response!"},
|
||||
{"tool_call_id", "call_911_"},
|
||||
}
|
||||
}), {}, false);
|
||||
caps_.supports_tool_responses = contains(out, "Some response!");
|
||||
caps_.supports_tool_call_id = contains(out, "call_911_");
|
||||
}
|
||||
|
||||
try {
|
||||
if (!caps_.supports_tools) {
|
||||
const json user_msg {
|
||||
{"role", "user"},
|
||||
{"content", "Hey"},
|
||||
};
|
||||
const json args {
|
||||
{"arg1", "some_value"},
|
||||
};
|
||||
const json tool_call_msg {
|
||||
{"role", "assistant"},
|
||||
{"content", caps_.requires_non_null_content ? "" : j_null},
|
||||
{"tool_calls", json::array({
|
||||
{
|
||||
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "tool_name"},
|
||||
{"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
|
||||
}},
|
||||
},
|
||||
})},
|
||||
};
|
||||
std::string prefix, full;
|
||||
{
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = json::array({user_msg});
|
||||
inputs.add_generation_prompt = true;
|
||||
prefix = apply(inputs);
|
||||
}
|
||||
{
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = json::array({user_msg, tool_call_msg});
|
||||
inputs.add_generation_prompt = false;
|
||||
full = apply(inputs);
|
||||
}
|
||||
auto eos_pos_last = full.rfind(eos_token_);
|
||||
if (eos_pos_last == prefix.size() - eos_token_.size() ||
|
||||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
|
||||
full = full.substr(0, eos_pos_last);
|
||||
}
|
||||
size_t common_prefix_length = 0;
|
||||
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||
if (prefix[i] != full[i]) {
|
||||
break;
|
||||
}
|
||||
if (prefix[i] == '<') {
|
||||
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||
// but it removes thinking tags for past messages.
|
||||
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||
continue;
|
||||
}
|
||||
common_prefix_length = i + 1;
|
||||
}
|
||||
auto example = full.substr(common_prefix_length);
|
||||
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
|
||||
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
||||
} else {
|
||||
tool_call_example_ = example;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
const std::string & source() const { return source_; }
|
||||
const std::string & bos_token() const { return bos_token_; }
|
||||
const std::string & eos_token() const { return eos_token_; }
|
||||
const chat_template_caps & original_caps() const { return caps_; }
|
||||
|
||||
// Deprecated, please use the form with chat_template_inputs and chat_template_options
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
|
||||
bool apply_polyfills = true)
|
||||
{
|
||||
fprintf(stderr, "[%s] Deprecated!\n", __func__);
|
||||
chat_template_inputs inputs;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
inputs.extra_context = extra_context;
|
||||
inputs.now = std::chrono::system_clock::now();
|
||||
|
||||
chat_template_options opts;
|
||||
opts.apply_polyfills = apply_polyfills;
|
||||
|
||||
return apply(inputs, opts);
|
||||
}
|
||||
|
||||
std::string apply(
|
||||
const chat_template_inputs & inputs,
|
||||
const chat_template_options & opts = chat_template_options()) const
|
||||
{
|
||||
json actual_messages;
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_tool_calls = false;
|
||||
auto has_tool_responses = false;
|
||||
auto has_string_content = false;
|
||||
for (const auto & message : inputs.messages) {
|
||||
if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
|
||||
has_tool_calls = true;
|
||||
}
|
||||
if (message.contains("role") && message["role"] == "tool") {
|
||||
has_tool_responses = true;
|
||||
}
|
||||
if (message.contains("content") && message["content"].is_string()) {
|
||||
has_string_content = true;
|
||||
}
|
||||
}
|
||||
|
||||
auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
|
||||
auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
|
||||
auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
|
||||
auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
|
||||
auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
|
||||
auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
|
||||
auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
|
||||
|
||||
auto needs_polyfills = opts.apply_polyfills && (false
|
||||
|| polyfill_system_role
|
||||
|| polyfill_tools
|
||||
|| polyfill_tool_calls
|
||||
|| polyfill_tool_responses
|
||||
|| polyfill_object_arguments
|
||||
|| polyfill_typed_content
|
||||
);
|
||||
|
||||
if (needs_polyfills) {
|
||||
actual_messages = json::array();
|
||||
|
||||
auto add_message = [&](const json & msg) {
|
||||
if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
|
||||
actual_messages.push_back({
|
||||
{"role", msg.at("role")},
|
||||
{"content", {{
|
||||
{"type", "text"},
|
||||
{"text", msg.at("content")},
|
||||
}}},
|
||||
});
|
||||
} else {
|
||||
actual_messages.push_back(msg);
|
||||
}
|
||||
};
|
||||
|
||||
std::string pending_system;
|
||||
auto flush_sys = [&]() {
|
||||
if (!pending_system.empty()) {
|
||||
add_message({
|
||||
{"role", "user"},
|
||||
{"content", pending_system},
|
||||
});
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
|
||||
json adjusted_messages;
|
||||
if (polyfill_tools) {
|
||||
adjusted_messages = add_system(inputs.messages,
|
||||
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
|
||||
} else {
|
||||
adjusted_messages = inputs.messages;
|
||||
}
|
||||
|
||||
for (const auto & message_ : adjusted_messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
|
||||
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
if (message.contains("tool_calls")) {
|
||||
if (polyfill_object_arguments || polyfill_tool_calls) {
|
||||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call["type"] == "function") {
|
||||
auto & function = tool_call.at("function");
|
||||
auto & arguments = function.at("arguments");
|
||||
if (arguments.is_string()) {
|
||||
try {
|
||||
arguments = json::parse(arguments.get<std::string>());
|
||||
} catch (const std::exception & ecvt) {
|
||||
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (polyfill_tool_calls) {
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_call.at("function");
|
||||
auto tc = json {
|
||||
{"name", function.at("name")},
|
||||
{"arguments", function.at("arguments")},
|
||||
};
|
||||
if (tool_call.contains("id")) {
|
||||
tc["id"] = tool_call["id"];
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (message.contains("content")) {
|
||||
auto content = message.at("content");
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
}
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
}
|
||||
}
|
||||
if (polyfill_tool_responses && role == "tool") {
|
||||
message["role"] = "user";
|
||||
auto obj = json {
|
||||
{"tool_response", json::object()},
|
||||
};
|
||||
if (message.contains("name")) {
|
||||
obj["tool_response"]["tool"] = message.at("name");
|
||||
}
|
||||
obj["tool_response"]["content"] = message.at("content");
|
||||
if (message.contains("tool_call_id")) {
|
||||
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("name");
|
||||
}
|
||||
|
||||
if (!message["content"].is_null() && polyfill_system_role) {
|
||||
std::string content = message.at("content");
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
pending_system += content;
|
||||
continue;
|
||||
} else {
|
||||
if (role == "user") {
|
||||
if (!pending_system.empty()) {
|
||||
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
|
||||
pending_system.clear();
|
||||
}
|
||||
} else {
|
||||
flush_sys();
|
||||
}
|
||||
}
|
||||
}
|
||||
add_message(message);
|
||||
}
|
||||
flush_sys();
|
||||
} else {
|
||||
actual_messages = inputs.messages;
|
||||
}
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
{"add_generation_prompt", inputs.add_generation_prompt},
|
||||
}));
|
||||
context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
|
||||
context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
|
||||
if (opts.define_strftime_now) {
|
||||
auto now = inputs.now;
|
||||
context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
|
||||
args.expectArgs("strftime_now", {1, 1}, {0, 0});
|
||||
auto format = args.args[0].get<std::string>();
|
||||
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
auto local_time = *std::localtime(&time);
|
||||
std::ostringstream ss;
|
||||
ss << std::put_time(&local_time, format.c_str());
|
||||
return ss.str();
|
||||
}));
|
||||
}
|
||||
if (!inputs.tools.is_null()) {
|
||||
context->set("tools", minja::Value(inputs.tools));
|
||||
}
|
||||
if (!inputs.extra_context.is_null()) {
|
||||
for (auto & kv : inputs.extra_context.items()) {
|
||||
context->set(kv.key(), minja::Value(kv.value()));
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = template_root_->render(context);
|
||||
// fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
|
||||
// fprintf(stderr, "apply: %s\n\n", ret.c_str());
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Returns a copy of `messages` that carries `system_prompt`.
// If the first message has role "system", the prompt is appended to its content after a blank
// line (the message is replaced by a fresh role/content pair); otherwise a new system message
// is inserted at the front.
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
    json result = messages;

    const bool leads_with_system = !result.empty() && result[0].at("role") == "system";
    if (!leads_with_system) {
        result.insert(result.begin(), json {
            {"role", "system"},
            {"content", system_prompt},
        });
        return result;
    }

    std::string current_system = result.at(0).at("content");
    result[0] = json {
        {"role", "system"},
        {"content", current_system + "\n\n" + system_prompt},
    };
    return result;
}
|
||||
};
|
||||
|
||||
} // namespace minja
|
||||
@@ -8,14 +8,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -1233,7 +1246,7 @@ public:
|
||||
}
|
||||
return result;
|
||||
|
||||
} else if (target_value.is_array()) {
|
||||
} else if (target_value.is_array()) {
|
||||
auto result = Value::array();
|
||||
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||
result.push_back(target_value.at(i));
|
||||
@@ -1278,6 +1291,12 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
static bool in(const Value & value, const Value & container) {
|
||||
return (((container.is_array() || container.is_object()) && container.contains(value)) ||
|
||||
(value.is_string() && container.is_string() &&
|
||||
container.to_str().find(value.to_str()) != std::string::npos));
|
||||
}
|
||||
|
||||
class BinaryOpExpr : public Expression {
|
||||
public:
|
||||
enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
|
||||
@@ -1342,13 +1361,8 @@ public:
|
||||
case Op::Gt: return l > r;
|
||||
case Op::Le: return l <= r;
|
||||
case Op::Ge: return l >= r;
|
||||
case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) ||
|
||||
(l.is_string() && r.is_string() &&
|
||||
r.to_str().find(l.to_str()) != std::string::npos));
|
||||
case Op::NotIn:
|
||||
return !(((r.is_array() || r.is_object()) && r.contains(l)) ||
|
||||
(l.is_string() && r.is_string() &&
|
||||
r.to_str().find(l.to_str()) != std::string::npos));
|
||||
case Op::In: return in(l, r);
|
||||
case Op::NotIn: return !in(l, r);
|
||||
default: break;
|
||||
}
|
||||
throw std::runtime_error("Unknown binary operator");
|
||||
@@ -1487,6 +1501,13 @@ public:
|
||||
} else if (method->get_name() == "pop") {
|
||||
vargs.expectArgs("pop method", {1, 1}, {0, 0});
|
||||
return obj.pop(vargs.args[0]);
|
||||
} else if (method->get_name() == "keys") {
|
||||
vargs.expectArgs("keys method", {0, 0}, {0, 0});
|
||||
auto result = Value::array();
|
||||
for (const auto& key : obj.keys()) {
|
||||
result.push_back(Value(key));
|
||||
}
|
||||
return result;
|
||||
} else if (method->get_name() == "get") {
|
||||
vargs.expectArgs("get method", {1, 2}, {0, 0});
|
||||
auto key = vargs.args[0];
|
||||
@@ -1528,6 +1549,16 @@ public:
|
||||
} else if (method->get_name() == "capitalize") {
|
||||
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
|
||||
return Value(capitalize(str));
|
||||
} else if (method->get_name() == "upper") {
|
||||
vargs.expectArgs("upper method", {0, 0}, {0, 0});
|
||||
auto result = str;
|
||||
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
|
||||
return Value(result);
|
||||
} else if (method->get_name() == "lower") {
|
||||
vargs.expectArgs("lower method", {0, 0}, {0, 0});
|
||||
auto result = str;
|
||||
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
|
||||
return Value(result);
|
||||
} else if (method->get_name() == "endswith") {
|
||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||
auto suffix = vargs.args[0].get<std::string>();
|
||||
@@ -1544,20 +1575,6 @@ public:
|
||||
else res[i] = std::tolower(res[i]);
|
||||
}
|
||||
return res;
|
||||
|
||||
} else if (method->get_name() == "replace") {
|
||||
vargs.expectArgs("replace method", {2, 3}, {0, 0});
|
||||
auto before = vargs.args[0].get<std::string>();
|
||||
auto after = vargs.args[1].get<std::string>();
|
||||
auto count = vargs.args.size() == 3 ? vargs.args[2].get<int64_t>()
|
||||
: str.length();
|
||||
size_t start_pos = 0;
|
||||
while ((start_pos = str.find(before, start_pos)) != std::string::npos &&
|
||||
count-- > 0) {
|
||||
str.replace(start_pos, before.length(), after);
|
||||
start_pos += after.length();
|
||||
}
|
||||
return str;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unknown method: " + method->get_name());
|
||||
@@ -2117,38 +2134,9 @@ private:
|
||||
std::shared_ptr<Expression> start, end, step;
|
||||
bool has_first_colon = false, has_second_colon = false;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if (!peekSymbols({ ":" })) {
|
||||
start = parseExpression();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
if (!consumeToken(":").empty()) {
|
||||
has_first_colon = true;
|
||||
@@ -2162,8 +2150,8 @@ private:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ((has_first_colon || has_second_colon)) {
|
||||
|
||||
if ((has_first_colon || has_second_colon) && (start || end || step)) {
|
||||
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
|
||||
} else {
|
||||
index = std::move(start);
|
||||
@@ -2663,17 +2651,13 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
auto items = Value::array();
|
||||
if (args.contains("object")) {
|
||||
auto & obj = args.at("object");
|
||||
if (obj.is_string()) {
|
||||
auto json_obj = json::parse(obj.get<std::string>());
|
||||
for (const auto & kv : json_obj.items()) {
|
||||
items.push_back(Value::array({kv.key(), kv.value()}));
|
||||
if (!obj.is_object()) {
|
||||
throw std::runtime_error("Can only get item pairs from a mapping");
|
||||
}
|
||||
} else if (!obj.is_null()) {
|
||||
for (auto & key : obj.keys()) {
|
||||
items.push_back(Value::array({key, obj.at(key)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
return items;
|
||||
}));
|
||||
globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
|
||||
@@ -2686,14 +2670,6 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
auto & text = args.at("text");
|
||||
return text.is_null() ? text : Value(strip(text.get<std::string>()));
|
||||
}));
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
auto char_transform_function = [](const std::string & name, const std::function<char(char)> & fn) {
|
||||
return simple_function(name, { "text" }, [=](const std::shared_ptr<Context> &, Value & args) {
|
||||
auto text = args.at("text");
|
||||
@@ -2807,6 +2783,9 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
if (!items.is_array()) throw std::runtime_error("object is not iterable");
|
||||
return items;
|
||||
}));
|
||||
globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
|
||||
return in(args.at("item"), args.at("items"));
|
||||
}));
|
||||
globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
|
||||
auto & items = args.at("items");
|
||||
if (!items.is_array()) throw std::runtime_error("object is not iterable");
|
||||
@@ -2846,16 +2825,10 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
if (filter_fn.is_null()) {
|
||||
throw std::runtime_error("Undefined filter: " + args.args[1].dump());
|
||||
}
|
||||
|
||||
|
||||
auto filter_args = Value::array();
|
||||
for (size_t i = 2, n = args.args.size(); i < n; i++) {
|
||||
|
||||
|
||||
filter_args.push_back(args.args[i]);
|
||||
|
||||
|
||||
|
||||
}
|
||||
auto filter = make_filter(filter_fn, filter_args);
|
||||
|
||||
@@ -2942,8 +2915,6 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
}
|
||||
test_args.kwargs = args.kwargs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
auto res = Value::array();
|
||||
for (size_t i = 0, n = items.size(); i < n; i++) {
|
||||
@@ -2957,10 +2928,7 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
} else {
|
||||
res.push_back(attr);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
return res;
|
||||
});
|
||||
};
|
||||
@@ -2978,7 +2946,6 @@ inline std::shared_ptr<Context> Context::builtins() {
|
||||
auto v = arg.get<int64_t>();
|
||||
startEndStep[i] = v;
|
||||
param_set[i] = true;
|
||||
|
||||
}
|
||||
}
|
||||
for (auto & [name, value] : args.kwargs) {
|
||||
@@ -118,7 +118,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '{' in pattern");
|
||||
}
|
||||
auto parts = string_split(std::string(start, it), ',');
|
||||
auto parts = string_split(std::string(start, it), ",");
|
||||
++it;
|
||||
if (parts.size() > 2) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
|
||||
@@ -9,8 +9,23 @@ enum common_regex_match_type {
|
||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||
};
|
||||
|
||||
// Include full definition of common_string_range
|
||||
#include "chat.h"
|
||||
// Pair of offsets [begin, end) with the invariant begin <= end enforced at construction.
struct common_string_range {
    size_t begin;
    size_t end;

    // A range must always be built from explicit bounds.
    common_string_range() = delete;

    // Throws std::runtime_error("Invalid range") when range_begin is past range_end.
    common_string_range(size_t range_begin, size_t range_end) : begin(range_begin), end(range_end) {
        if (range_end < range_begin) {
            throw std::runtime_error("Invalid range");
        }
    }

    // True when the range covers no characters.
    bool empty() const {
        return end == begin;
    }

    // Two ranges are equal when both bounds match.
    bool operator==(const common_string_range & rhs) const {
        if (begin != rhs.begin) {
            return false;
        }
        return end == rhs.end;
    }
};
|
||||
|
||||
struct common_regex_match {
|
||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
#define LLAMA_API_INTERNAL
|
||||
#include "sampling.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "common.h"
|
||||
#include <random>
|
||||
#include "json.hpp"
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * result = new llama_sampling_context();
|
||||
@@ -9,10 +12,68 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
|
||||
result->params = params;
|
||||
result->grammar = nullptr;
|
||||
|
||||
|
||||
struct llama_grammar* grmr;
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
}
|
||||
else {
|
||||
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto& trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto& word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
{
|
||||
const auto token = trigger.token;
|
||||
trigger_tokens.push_back(token);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown trigger type");
|
||||
}
|
||||
}
|
||||
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
std::vector<const char*> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||
for (const auto& regex : trigger_patterns) {
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
|
||||
// if there is a grammar, parse it
|
||||
if (!params.grammar.empty()) {
|
||||
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
|
||||
if (result->parsed_grammar.success) {
|
||||
// will be empty (default) if there are parse errors
|
||||
if (result->parsed_grammar.rules.empty()) {
|
||||
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||
@@ -26,21 +87,15 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
||||
|
||||
struct llama_grammar * grammar = llama_grammar_init(
|
||||
grammar_rules.data(),
|
||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||
if (grammar == nullptr) {
|
||||
if (grmr == nullptr) {
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
}
|
||||
result->grammar = grammar;
|
||||
}
|
||||
}
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
result->n_valid = 0;
|
||||
|
||||
}
|
||||
result->grammar = grmr;
|
||||
// init DRY
|
||||
for (const auto& cnstr : params.samplers_sequence)
|
||||
{
|
||||
@@ -75,27 +130,71 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||
void llama_sampling_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx) {
|
||||
|
||||
if (ctx->grammar != NULL) {
|
||||
llama_grammar_free(ctx->grammar);
|
||||
ctx->grammar = NULL;
|
||||
}
|
||||
|
||||
if (!ctx->parsed_grammar.rules.empty()) {
|
||||
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
||||
|
||||
struct llama_grammar * grammar = llama_grammar_init(
|
||||
grammar_rules.data(),
|
||||
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
||||
if (grammar == nullptr) {
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
}
|
||||
ctx->grammar = grammar;
|
||||
struct llama_grammar* grmr;
|
||||
auto params = ctx->params;
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
}
|
||||
else {
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto& trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto& word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
{
|
||||
const auto token = trigger.token;
|
||||
trigger_tokens.push_back(token);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown trigger type");
|
||||
}
|
||||
}
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||
ctx->cur.clear();
|
||||
ctx->n_valid = 0;
|
||||
std::vector<const char*> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||
for (const auto& regex : trigger_patterns) {
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
}
|
||||
|
||||
ctx->grammar = grmr;
|
||||
llama_sampler_dry_reset(ctx->smpl);
|
||||
}
|
||||
|
||||
@@ -498,7 +597,10 @@ void llama_sampling_accept(
|
||||
struct llama_context * ctx_main,
|
||||
llama_token id,
|
||||
bool apply_grammar) {
|
||||
if (ctx_sampling->prev.size() > 0) {
|
||||
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
||||
|
||||
}
|
||||
ctx_sampling->prev.push_back(id);
|
||||
|
||||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||
@@ -552,3 +654,29 @@ std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_samplin
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template <>
|
||||
json common_grammar_trigger::to_json() const {
|
||||
json out{
|
||||
{"type", (int)type},
|
||||
{"value", value},
|
||||
};
|
||||
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out["token"] = (int)token;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
common_grammar_trigger common_grammar_trigger::from_json(const json& in) {
|
||||
common_grammar_trigger out;
|
||||
out.type = (common_grammar_trigger_type)in.at("type").get<int>();
|
||||
out.value = in.at("value").get<std::string>();
|
||||
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out.token = (llama_token)in.at("token").get<int>();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <set>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
@@ -22,6 +21,23 @@ enum class llama_sampler_type : char {
|
||||
TEMPERATURE = 't'
|
||||
};
|
||||
|
||||
// Kind of a grammar trigger. llama_sampling_init and llama_sampling_reset switch on
// this value while collecting the trigger patterns and trigger tokens they hand to
// the grammar initializer.
enum common_grammar_trigger_type {
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, // trigger.token is appended to the trigger-token list
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD, // trigger.value is regex-escaped, then matched anywhere in the text
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, // trigger.value is used as a regex matched anywhere in the text
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, // trigger.value is added to the trigger patterns unchanged
|
||||
};
|
||||
|
||||
// One grammar trigger: its kind plus the payload for that kind.
struct common_grammar_trigger {
|
||||
common_grammar_trigger_type type; // selects which of the fields below is consumed
|
||||
std::string value; // word or regex text; read for the WORD / PATTERN / PATTERN_FULL kinds
|
||||
llama_token token = LLAMA_TOKEN_NULL; // read for the TOKEN kind; serialized only for that kind
|
||||
|
||||
// T can only be nlohmann::ordered_json
|
||||
template <class T> T to_json() const;
|
||||
template <class T> static common_grammar_trigger from_json(const T& in);
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
@@ -67,8 +83,11 @@ typedef struct llama_sampling_params {
|
||||
llama_sampler_type::TEMPERATURE
|
||||
};
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
// Classifier-Free Guidance
|
||||
// https://arxiv.org/abs/2306.17806
|
||||
std::string cfg_negative_prompt; // string to help guidance
|
||||
@@ -106,7 +125,7 @@ struct llama_sampling_context {
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
#include "common.h"
|
||||
|
||||
|
||||
// Create a new sampling context instance.
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params);
|
||||
@@ -116,7 +135,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||
// Reset the sampler context
|
||||
// - clear prev tokens
|
||||
// - reset grammar
|
||||
void llama_sampling_reset(llama_sampling_context * ctx);
|
||||
void llama_sampling_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx);
|
||||
|
||||
// Set the sampler seed
|
||||
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
|
||||
@@ -186,3 +205,6 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_con
|
||||
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
|
||||
|
||||
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
|
||||
|
||||
llama_grammar* llama_sampler_init_llg(const llama_vocab* vocab,
|
||||
const char* grammar_kind, const char* grammar_data);
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include "llama-vocab.h"
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
@@ -302,7 +302,7 @@ std::vector<llama_token> llama_speculative_gen_draft(
|
||||
|
||||
llama_decode(ctx_dft, batch);
|
||||
|
||||
llama_sampling_reset(smpl);
|
||||
llama_sampling_reset(llama_get_vocab(ctx_dft), smpl);
|
||||
|
||||
// sample n_draft tokens from the draft model
|
||||
for (int i = 0; i < params.n_draft; ++i) {
|
||||
|
||||
Reference in New Issue
Block a user