Tool calls support from mainline (#723)

* Tool calls support from mainline

* update cmake

* revert api for /completions

* Fix broken thinking process for gpt-oss

* add missing args and fix webui bugs

* add missing args and fix webui bugs (2)

* Fix reasoning format error

* add usage

* change default post_sampling_probs to true

* add back generated_text

* Remove server endpoints tests

* add log

* Chat fixes

* Remove logs

* webui: revert extra handling of thinking process

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: firecoperana
Committed by: GitHub
Date: 2025-09-01 00:38:49 -05:00
Parent: 8de297b795
Commit: d7882c3cf8
87 changed files with 13581 additions and 2224 deletions


@@ -52,17 +52,13 @@ set(TARGET common)
add_library(${TARGET} STATIC
base64.hpp
chat-template.hpp
common.h
chat.cpp
chat.h
chat-parser.cpp
chat-parser.h
common.cpp
chat.h
chat.cpp
chat-parser.h
chat-parser.cpp
json-partial.h
json-partial.cpp
regex-partial.h
regex-partial.cpp
sampling.h
sampling.cpp
console.h
@@ -70,13 +66,19 @@ add_library(${TARGET} STATIC
grammar-parser.h
grammar-parser.cpp
json.hpp
json-partial.h
json-partial.cpp
llguidance.cpp
json-schema-to-grammar.cpp
train.h
train.cpp
minja.hpp
minja/chat-template.hpp
minja/minja.hpp
ngram-cache.h
ngram-cache.cpp
speculative.cpp
regex-partial.cpp
regex-partial.h
)
if (BUILD_SHARED_LIBS)
@@ -94,6 +96,33 @@ if (LLAMA_CURL)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
endif ()
if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
# v0.6.12:
GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
PREFIX ${CMAKE_BINARY_DIR}/llguidance
SOURCE_DIR ${LLGUIDANCE_SRC}
BUILD_IN_SOURCE TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND cargo build --release
INSTALL_COMMAND ""
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
UPDATE_COMMAND ""
)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
add_library(llguidance STATIC IMPORTED)
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
add_dependencies(llguidance llguidance_ext)
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
endif ()
target_include_directories(${TARGET} PUBLIC .)
target_compile_features (${TARGET} PUBLIC cxx_std_11)
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)


@@ -1,54 +1,69 @@
// Chat parser implementation
#include "chat-parser.h"
#include "../examples/server/parsers/kimi_k2_parser.hpp"
#include "json.hpp"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax) {
// Initialize result with default role
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
result_.role = "assistant";
while (true) {
std::string id = std::to_string(std::rand());
if (input.find(id) == std::string::npos) {
healing_marker_ = id;
break;
}
}
}
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
if (rng.begin > input_.size() || rng.end > input_.size()) {
throw std::runtime_error("Range out of bounds");
}
GGML_ASSERT(rng.begin <= rng.end);
return input_.substr(rng.begin, rng.end - rng.begin);
}
void common_chat_msg_parser::add_content(const std::string & content) {
void common_chat_msg_parser::add_content(const std::string &content) {
result_.content += content;
}
void common_chat_msg_parser::add_reasoning_content(const std::string & reasoning_content) {
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
result_.reasoning_content += reasoning_content;
}
void common_chat_msg_parser::add_tool_call(const common_chat_tool_call & tool_call) {
result_.tool_calls.push_back(tool_call);
}
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
if (name.empty()) {
return false;
}
common_chat_tool_call tool_call;
tool_call.name = name;
tool_call.arguments = arguments;
tool_call.id = id;
// LOG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
result_.tool_calls.emplace_back(tool_call);
return true;
}
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
std::string arguments = "";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else {
arguments = tool_call.at("arguments");
}
}
return add_tool_call(name, id, arguments);
}
@@ -60,25 +75,65 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) {
}
return true;
}
void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear();
void common_chat_msg_parser::finish() {
if (!is_partial_ && pos_ != input_.size()) {
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
}
}
std::string common_chat_msg_parser::consume_rest() {
auto rest = input_.substr(pos_);
pos_ = input_.size();
return rest;
bool common_chat_msg_parser::consume_spaces() {
const auto length = input_.size();
auto consumed = false;
while (pos_ < length && std::isspace(input_[pos_])) {
++pos_;
consumed = true;
}
return consumed;
}
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
if (pos_ + literal.size() <= input_.size()) {
if (input_.substr(pos_, literal.size()) == literal) {
pos_ += literal.size();
return true;
auto pos = pos_;
for (auto i = 0u; i < literal.size(); ++i) {
if (pos >= input_.size()) {
return false;
}
if (input_[pos] != literal[i]) {
return false;
}
++pos;
}
pos_ = pos;
return true;
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
auto idx = input_.find(literal, pos_);
if (idx != std::string::npos) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = idx + literal.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
if (is_partial_) {
idx = string_find_partial_stop(input_, literal);
if (idx != std::string::npos && idx >= pos_) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = input_.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
}
return false;
return std::nullopt;
}
void common_chat_msg_parser::consume_literal(const std::string & literal) {
if (!try_consume_literal(literal)) {
throw common_chat_msg_partial_exception(literal);
}
}
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
@@ -97,7 +152,6 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
add_reasoning_content(stripped_reasoning);
}
};
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
if (auto res = try_find_literal(end_think)) {
@@ -109,198 +163,73 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
// Allow unclosed thinking tags for now (following original llama.cpp)
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
// if (!syntax_.thinking_forced_open) {
// throw common_chat_msg_partial_exception(end_think);
// }
return true;
}
}
return false;
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal_legacy(const std::string & literal) {
auto idx = input_.find(literal, pos_);
if (idx != std::string::npos) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = idx + literal.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
if (is_partial_) {
idx = string_find_partial_stop(input_, literal);
if (idx != std::string::npos && idx >= pos_) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = input_.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
}
return std::nullopt;
}
void common_chat_msg_parser::parse() {
switch (syntax_.format) {
case COMMON_CHAT_FORMAT_KIMI_K2:
parse_kimi_k2_format();
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
parse_deepseek_r1_format();
break;
case COMMON_CHAT_FORMAT_GENERIC:
parse_generic_format();
break;
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
add_content(consume_rest());
break;
default:
// Fallback to content-only for now
add_content(consume_rest());
break;
}
}
void common_chat_msg_parser::parse_kimi_k2_format() {
json tool_calls_json = kimi_k2::parse_tool_calls(input_);
if (is_partial_ && kimi_k2::is_partial_content_advanced(input_)) {
throw common_chat_msg_partial_exception("partial structured content detected");
}
bool has_function_syntax = input_.find("functions.") != std::string::npos;
bool parsing_succeeded = !tool_calls_json.empty();
if (has_function_syntax && !parsing_succeeded) {
throw std::runtime_error("malformed function call syntax detected");
}
if (!tool_calls_json.empty()) {
for (const auto& tc_json : tool_calls_json) {
try {
common_chat_tool_call tc;
tc.id = tc_json.value("id", "");
if (!tc_json.contains("function") || !tc_json["function"].contains("name")) {
continue;
}
tc.name = tc_json["function"]["name"];
if (tc.name.empty()) {
continue;
}
tc.arguments = tc_json["function"]["arguments"];
if (!is_partial_ && !tc.arguments.empty()) {
try {
auto parsed = json::parse(tc.arguments);
(void)parsed;
} catch (const std::exception&) {
continue;
}
}
add_tool_call(tc);
} catch (const std::exception&) {
continue;
}
}
add_content(kimi_k2::clean_content(input_));
} else {
add_content(input_);
}
std::string common_chat_msg_parser::consume_rest() {
auto rest = input_.substr(pos_);
pos_ = input_.size();
return rest;
}
void common_chat_msg_parser::parse_generic_format() {
add_content(consume_rest());
}
void common_chat_msg_parser::parse_deepseek_r1_format() {
// Delegate to the main chat.cpp function which has the corrected implementation
// This follows the original llama.cpp pattern where chat-parser delegates to chat.cpp
common_chat_parse_deepseek_r1(*this);
}
void common_chat_msg_parser::finish() {
// Any final processing can go here
}
common_chat_msg common_chat_msg_parser::result_and_reset() {
auto msg = result_;
result_ = common_chat_msg();
result_.role = "assistant";
pos_ = 0;
return msg;
}
// Content-only parsing for fallback scenarios
// Format detection from chat template patterns (focused on DeepSeek R1 and Kimi K2)
common_chat_format common_chat_format_detect(const std::string & chat_template) {
if (chat_template.empty()) {
return COMMON_CHAT_FORMAT_GENERIC;
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
// Detect DeepSeek R1 format (following original llama.cpp detection logic)
if (chat_template.find("<tool▁calls▁begin>") != std::string::npos) {
return COMMON_CHAT_FORMAT_DEEPSEEK_R1;
}
// Detect Kimi K2 format (our custom format)
if (chat_template.find("kimi") != std::string::npos ||
chat_template.find("Kimi") != std::string::npos ||
chat_template.find("functions.") != std::string::npos) {
return COMMON_CHAT_FORMAT_KIMI_K2;
}
// Default to generic format for unknown templates
return COMMON_CHAT_FORMAT_GENERIC;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
// Progressive parsing primitive - find literal (following original llama.cpp pattern)
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
auto idx = input_.find(literal, pos_);
if (idx != std::string::npos) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = idx + literal.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
if (add_prelude_to_content) {
add_content(prelude);
}
if (is_partial_) {
idx = string_find_partial_stop(input_, literal);
if (idx != std::string::npos && idx >= pos_) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = input_.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
return std::nullopt;
return find_regex_result{prelude, m.groups};
}
bool common_chat_msg_parser::consume_spaces() {
bool consumed = false;
while (pos_ < input_.length() && std::isspace(input_[pos_])) {
pos_++;
consumed = true;
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
if (auto result = try_consume_regex(regex)) {
return *result;
}
return consumed;
throw common_chat_msg_partial_exception(regex.str());
}
void common_chat_msg_parser::set_healing_marker(const std::string & marker) {
healing_marker_ = marker;
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
auto m = regex.search(input_, pos_);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
if (m.groups[0].begin != pos_) {
// Didn't match at the current position.
return std::nullopt;
}
pos_ = m.groups[0].end;
return find_regex_result {
/* .prelude = */ "",
m.groups,
};
}
// Enhanced JSON parsing methods (following original llama.cpp patterns exactly)
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
auto it = input_.cbegin() + pos_;
const auto end = input_.cend();
@@ -327,8 +256,8 @@ common_json common_chat_msg_parser::consume_json() {
}
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
const std::vector<std::vector<std::string>>& args_paths,
const std::vector<std::vector<std::string>>& content_paths
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
return *result;
@@ -337,8 +266,8 @@ common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json
}
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>>& args_paths,
const std::vector<std::vector<std::string>>& content_paths
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
auto partial = try_consume_json();
if (!partial) {
@@ -366,137 +295,99 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
/* .is_partial = */ false,
};
}
// TODO: Implement full path-based argument dumping logic from original
// For now, return the parsed JSON as-is
return consume_json_result {
partial->json,
/* .is_partial = */ false,
};
}
// Has healing marker - this is partial JSON
// TODO: Implement sophisticated partial JSON handling with path-based dumping
// For now, return partial result
return consume_json_result {
partial->json,
/* .is_partial = */ true,
};
}
bool common_chat_msg_parser::detect_partial_function_call(const std::string& content) {
if (content.empty()) return false;
// Enhanced partial detection patterns
static const std::vector<std::string> partial_patterns = {
"functions",
"functions.",
"<tool_call",
"<tool_call>",
"<invoke",
"<|tool_calls_section_begin|>",
"<|tool_call_begin|>"
};
for (const auto& pattern : partial_patterns) {
if (content.substr(0, pattern.length()) == pattern && content.length() <= pattern.length() + 50) {
return true;
}
}
return false;
}
LOG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
void common_chat_msg_parser::handle_partial_detection() {
if (!is_partial_) return;
// Check for various partial patterns
std::string remaining = input_.substr(pos_);
if (remaining.empty()) return;
// Detect partial function calls
if (detect_partial_function_call(remaining)) {
set_healing_marker(remaining);
throw common_chat_msg_partial_exception("partial function call detected");
}
// Enhanced partial JSON detection
if (remaining.find('{') != std::string::npos) {
size_t brace_pos = remaining.find('{');
std::string json_part = remaining.substr(brace_pos);
// Check if JSON is incomplete
int brace_count = 0;
bool in_string = false;
bool escaped = false;
bool is_incomplete = true;
for (size_t i = 0; i < json_part.length(); i++) {
char c = json_part[i];
if (!escaped) {
if (c == '"' && !in_string) {
in_string = true;
} else if (c == '"' && in_string) {
in_string = false;
} else if (!in_string) {
if (c == '{') brace_count++;
else if (c == '}') brace_count--;
auto found_healing_marker = false;
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {
arguments.resize(idx);
found_healing_marker = true;
}
if (arguments == "\"") {
// This happens because of completing `:"$magic` after `"arguments"`
arguments = "";
}
}
escaped = (!escaped && c == '\\');
if (brace_count == 0) {
is_incomplete = false;
break;
return arguments;
}
if (is_content_path(path)) {
if (!j.is_string()) {
throw std::runtime_error("Content path must be a string");
}
std::string str = j;
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
if (idx != std::string::npos) {
str.resize(idx);
found_healing_marker = true;
}
return str;
}
if (is_incomplete) {
set_healing_marker(json_part);
throw common_chat_msg_partial_exception("partial JSON detected");
if (j.is_object()) {
auto obj = json::object();
for (const auto & p : j.items()) {
const auto & key = p.key();
const auto & value = p.value();
const std::string key_str = key; // NOLINT
auto idx = key_str.find(healing_marker_);
if (idx != std::string::npos) {
found_healing_marker = true;
break;
}
path.push_back(key_str);
if (value.is_string()) {
const std::string value_str = value;
if (value_str.find(healing_marker_) != std::string::npos) {
found_healing_marker = true;
if (is_content_path(path)) {
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
}
break;
}
obj[key] = value;
} else {
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
path.pop_back();
}
return obj;
}
}
}
// Regex-based parsing methods (ported from original llama.cpp)
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
if (add_prelude_to_content) {
add_content(prelude);
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
if (j.is_array()) {
auto arr = json::array();
for (const auto & value : j) {
if (value.is_string()) {
std::string str = value;
auto idx = str.find(healing_marker_);
if (idx != std::string::npos) {
// Don't heal array values that aren't in the arguments.
found_healing_marker = true;
break;
}
}
arr.push_back(remove_unsupported_healings_and_dump_args(value));
}
return arr;
}
return std::nullopt;
}
return find_regex_result{prelude, m.groups};
return j;
};
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
LOG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
return consume_json_result {
cleaned,
/* .is_partial = */ found_healing_marker,
};
}
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
auto result = try_find_regex(regex);
if (!result) {
throw std::runtime_error("Expected regex not found: " + regex.str());
}
return *result;
void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear();
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
return try_find_regex(regex, pos_, false);
}
void common_chat_msg_parser::consume_literal(const std::string & literal) {
if (!try_consume_literal(literal)) {
throw std::runtime_error("Expected literal not found: " + literal);
}
}
// Get format name for debugging/logging (implemented in chat.cpp)
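To make these primitives easier to follow, here is a minimal, hypothetical sketch (not part of this commit) of how a format handler might compose them. The <tool_call> tags, the "arguments" path, and the overall flow are illustrative assumptions, not the committed parsers:

#include "chat-parser.h"

// Hypothetical sketch: composing the parser primitives above (assumed tags).
static void parse_example_format(common_chat_msg_parser & builder) {
    // 1) An optional leading reasoning block goes into result.reasoning_content.
    builder.try_parse_reasoning("<think>", "</think>");

    // 2) Everything before a tool-call opener is plain content.
    if (auto res = builder.try_find_literal("<tool_call>")) {
        builder.add_content(res->prelude);

        // 3) Parse the (possibly partial) JSON payload; the "arguments" subtree is
        //    kept as a JSON-dumped string even when truncated.
        auto payload = builder.consume_json_with_dumped_args({{"arguments"}});
        builder.add_tool_call(payload.value);
        builder.try_consume_literal("</tool_call>");
    }

    // 4) Whatever remains is content; finish() checks the input was fully consumed.
    builder.add_content(builder.consume_rest());
    builder.finish();
}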


@@ -1,14 +1,18 @@
// Chat parser with builder pattern for incremental parsing
#pragma once
#include "chat.h"
#include "json-partial.h"
#include "json.hpp"
#include "regex-partial.h"
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
class common_chat_msg_partial_exception : public std::runtime_error {
public:
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};
class common_chat_msg_parser {
std::string input_;
@@ -20,14 +24,7 @@ class common_chat_msg_parser {
common_chat_msg result_;
public:
struct find_regex_result {
std::string prelude;
std::vector<common_string_range> groups;
};
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
// Accessors
const std::string & input() const { return input_; }
size_t pos() const { return pos_; }
const std::string & healing_marker() const { return healing_marker_; }
@@ -35,14 +32,12 @@ class common_chat_msg_parser {
const common_chat_msg & result() const { return result_; }
const common_chat_syntax & syntax() const { return syntax_; }
// Position manipulation
void move_to(size_t pos) {
if (pos > input_.size()) {
throw std::runtime_error("Invalid position!");
}
pos_ = pos;
}
void move_back(size_t n) {
if (pos_ < n) {
throw std::runtime_error("Can't move back that far!");
@@ -53,84 +48,72 @@ class common_chat_msg_parser {
// Get the substring of the input at the given range
std::string str(const common_string_range & rng) const;
// Content manipulation
// Appends to the result.content field
void add_content(const std::string & content);
// Appends to the result.reasoning_content field
void add_reasoning_content(const std::string & reasoning_content);
// Tool call manipulation
void add_tool_call(const common_chat_tool_call & tool_call);
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
bool add_tool_call(const json & tool_call);
bool add_tool_calls(const json & arr);
void clear_tools();
// Parsing utilities
std::string consume_rest();
bool try_consume_literal(const std::string & literal);
void consume_literal(const std::string & literal);
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
bool add_tool_call(const nlohmann::ordered_json & tool_call);
// Regex-based parsing methods (new)
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
find_regex_result consume_regex(const common_regex & regex);
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
// Progressive parsing primitives (for Phase 4)
std::optional<find_regex_result> try_find_literal(const std::string & literal);
bool consume_spaces();
void set_healing_marker(const std::string & marker);
// Main parsing entry point
void parse();
// Finishing
void finish();
// Result extraction
common_chat_msg result_and_reset();
bool consume_spaces();
// Advanced JSON parsing (following original llama.cpp patterns)
struct consume_json_result {
json value;
bool is_partial;
void consume_literal(const std::string & literal);
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
std::string consume_rest();
struct find_regex_result {
std::string prelude;
std::vector<common_string_range> groups;
};
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
bool try_consume_literal(const std::string & literal);
std::optional<find_regex_result> try_find_literal(const std::string & literal);
find_regex_result consume_regex(const common_regex & regex);
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
std::optional<common_json> try_consume_json();
common_json consume_json();
consume_json_result consume_json_with_dumped_args(
const std::vector<std::vector<std::string>>& args_paths = {},
const std::vector<std::vector<std::string>>& content_paths = {}
);
std::optional<consume_json_result> try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>>& args_paths = {},
const std::vector<std::vector<std::string>>& content_paths = {}
);
private:
// Internal parsing helpers
void parse_kimi_k2_format();
void parse_deepseek_r1_format();
void parse_generic_format();
// JSON parsing utilities (enhanced streaming support)
struct json_parse_result {
json value;
bool success;
struct consume_json_result {
nlohmann::ordered_json value;
bool is_partial;
std::string healing_marker;
};
// Partial detection utilities
bool detect_partial_function_call(const std::string& content);
void handle_partial_detection();
/*
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
// Legacy find_literal for compatibility
std::optional<find_regex_result> try_find_literal_legacy(const std::string & literal);
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
*/
consume_json_result consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
std::optional<consume_json_result> try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
void clear_tools();
};
// Main parsing function (public API)
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
// Content-only parsing for fallback scenarios (static internal function)
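As a usage illustration of the consume_json_with_dumped_args contract described in the comment above, a small, hypothetical sketch (the input string and field names are assumptions):

#include "chat-parser.h"

// Hypothetical sketch of the documented truncation behaviour (assumed input).
static void example_partial_arguments() {
    common_chat_syntax syntax;   // defaults: CONTENT_ONLY format, no reasoning parsing

    // Raw, partial model output: {"name": "get_weather", "arguments": {"city": "Par
    common_chat_msg_parser p(
        "{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Par",
        /* is_partial = */ true,
        syntax);

    if (auto res = p.try_consume_json_with_dumped_args(/* args_paths = */ {{"arguments"}})) {
        // res->is_partial is true, and res->value holds something like
        //   {"name": "get_weather", "arguments": "{\"city\": \"Par"}
        // i.e. the "arguments" subtree is dumped to a (truncated) JSON string,
        // matching the `{"foo": {"b` -> `{"foo": "{b"}` example above.
    }
}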


@@ -1,249 +0,0 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
namespace minja {
class chat_template {
public:
private:
bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;
std::string try_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
// fprintf(stderr, "Error: %s\n", e.what());
return "";
}
}
public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
{
template_root_ = minja::Parser::parse(source_, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
supports_tools_ = source.find("tools") != std::string::npos;
auto renders_string_arguments =
try_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", {
{"code", "print('Hello, World!')"},
}},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
requires_object_arguments_ = renders_object_arguments;
}
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
supports_system_role_ = try_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;
}
const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
bool supports_tools() const { return supports_tools_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
json actual_messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
actual_messages = json::array();
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (const auto & message_ : messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (message.contains("tool_calls")) {
if (requires_object_arguments_ || !supports_tools_) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
if (!supports_tools_) {
auto content = message.at("content");
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
if (tool_call.at("type") != "function") {
continue;
}
const auto & function = tool_call.at("function");
auto tc = json {
{"name", function.at("name")},
{"arguments", function.at("arguments")},
};
if (tool_call.contains("id")) {
tc["id"] = tool_call["id"];
}
tool_calls.push_back(tc);
}
auto obj = json {
{"tool_calls", tool_calls},
};
if (!content.is_null() && content != "") {
obj["content"] = content;
}
message["content"] = obj.dump(2);
message.erase("tool_calls");
}
}
if (!supports_tools_ && role == "tool") {
message["role"] = "user";
auto obj = json {
{"tool_response", {
{"tool", message.at("name")},
{"content", message.at("content")},
}},
};
if (message.contains("tool_call_id")) {
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
}
message["content"] = obj.dump(2);
message.erase("name");
}
if (!message["content"].is_null() && !supports_system_role_) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
actual_messages.push_back(message);
}
flush_sys();
} else {
actual_messages = messages;
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", bos_token_},
{"eos_token", eos_token_},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return template_root_->render(context);
}
};
} // namespace minja

(File diff suppressed because it is too large.)


@@ -1,37 +1,16 @@
// Chat support with builder pattern for llama.cpp compatibility
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <functional>
#include <chrono>
#include <string>
#include <vector>
#include <functional>
#include <map>
// Forward declarations
struct common_chat_templates;
// Basic data structures compatible with original llama.cpp
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
struct common_chat_tool_call {
std::string name;
std::string arguments;
@@ -40,10 +19,6 @@ struct common_chat_tool_call {
bool operator==(const common_chat_tool_call & other) const {
return name == other.name && arguments == other.arguments && id == other.id;
}
bool operator!=(const common_chat_tool_call & other) const {
return !(*this == other);
}
};
struct common_chat_msg_content_part {
@@ -64,11 +39,11 @@ struct common_chat_msg {
std::string tool_name;
std::string tool_call_id;
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() &&
reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
}
template <class T> T to_json_oaicompat() const;
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
}
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
for (auto i = 0u; i < tool_calls.size(); i++) {
if (ids_cache.size() <= i) {
@@ -81,7 +56,6 @@ struct common_chat_msg {
tool_calls[i].id = ids_cache[i];
}
}
bool operator==(const common_chat_msg & other) const {
return role == other.role
&& content == other.content
@@ -91,7 +65,6 @@ struct common_chat_msg {
&& tool_name == other.tool_name
&& tool_call_id == other.tool_call_id;
}
bool operator!=(const common_chat_msg & other) const {
return !(*this == other);
}
@@ -110,10 +83,6 @@ struct common_chat_msg_diff {
&& tool_call_index == other.tool_call_index
&& tool_call_delta == other.tool_call_delta;
}
bool operator!=(const common_chat_msg_diff & other) const {
return !(*this == other);
}
};
struct common_chat_tool {
@@ -131,50 +100,110 @@ enum common_chat_tool_choice {
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility)
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
struct common_chat_templates_inputs {
std::vector<common_chat_msg> messages;
std::string grammar;
std::string json_schema;
bool add_generation_prompt = true;
bool use_jinja = true;
// Parameters below only supported when use_jinja is true
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;
bool add_bos = false;
bool add_eos = false;
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
bool thinking_forced_open = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_syntax {
common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE;
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool enable_thinking = false;
bool enable_tool_calls = true;
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
};
// Exception for partial parsing
class common_chat_msg_partial_exception : public std::runtime_error {
public:
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
// Bridge functions to integrate with existing ik_llama.cpp system
// TODO: Uncomment and implement during integration phase
// common_chat_msg ik_to_common_msg(const struct ik_chat_msg & ik_msg);
// struct ik_chat_msg common_to_ik_msg(const common_chat_msg & common_msg);
void common_chat_templates_free(struct common_chat_templates * tmpls);
// Format detection from chat template
common_chat_format common_chat_format_detect(const std::string & chat_template);
const char* common_chat_format_name(common_chat_format format);
const char* common_reasoning_format_name(common_reasoning_format format);
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
// Main parsing function (entry point for original llama.cpp compatibility)
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
// Forward declare parser class
class common_chat_msg_parser;
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");
// Format-specific parsing functions (accessible from chat-parser)
void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder);
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const struct common_chat_templates * tmpls,
bool use_jinja);
const char* common_chat_format_name(common_chat_format format);
const char* common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string& format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
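To show how the declarations above fit together, a minimal, hypothetical sketch of parsing a finished response; the <think>/</think> tags and the sample output string are assumptions for illustration:

#include "chat.h"

// Hypothetical sketch, using only the API declared above (assumed model output).
static void example_parse_response() {
    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
    syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;

    common_chat_msg msg = common_chat_parse(
        "<think>The user asked about the weather.</think>It is sunny today.",
        /* is_partial = */ false,
        syntax);

    // Expected (illustrative): msg.reasoning_content holds the stripped <think> text,
    // msg.content holds "It is sunny today.", and msg.tool_calls stays empty.
}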


@@ -13,10 +13,10 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama-vocab.h"
#include "llama.h"
#include "chat-template.hpp"
#include "chat.h"
#include "json-schema-to-grammar.h"
#include <algorithm>
#include <cinttypes>
#include <climits>
@@ -230,13 +230,13 @@ void gpt_params_handle_model_default(gpt_params & params) {
}
params.hf_file = params.model;
} else if (params.model.empty()) {
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
params.model = fs_get_cache_file(string_split(params.hf_file, "/").back());
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
auto f = string_split(params.model_url, '#').front();
f = string_split(f, '?').front();
params.model = fs_get_cache_file(string_split(f, '/').back());
auto f = string_split(params.model_url, "#").front();
f = string_split(f, "?").front();
params.model = fs_get_cache_file(string_split(f, "/").back());
}
} else if (params.model.empty()) {
params.model = DEFAULT_MODEL_PATH;
@@ -295,7 +295,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
if (!params.chat_template.empty() && !llama_chat_verify_template(nullptr, params.chat_template, params.use_jinja)) {
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
params.chat_template.c_str(),
@@ -599,7 +599,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--samplers") {
CHECK_ARG
const auto sampler_names = string_split(argv[i], ';');
const auto sampler_names = string_split(argv[i], ";");
sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
return true;
}
@@ -1486,6 +1486,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
if (arg == "--reasoning-budget") {
CHECK_ARG
params.reasoning_budget = std::stoi(argv[i]);
return true;
}
if (arg == "--sql-save-file") {
CHECK_ARG
params.sql_save_file = argv[i];
@@ -1498,7 +1503,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--chat-template") {
CHECK_ARG
if (!llama_chat_verify_template(nullptr, argv[i], false)) {
if (!common_chat_verify_template(argv[i], true)) {
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
invalid_param = true;
@@ -1510,9 +1515,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
if (arg == "--chat-template-file") {
CHECK_ARG
std::string chat_template = read_file(std::string(argv[i]));
if (!llama_chat_verify_template(nullptr, chat_template, false)) {
if (!common_chat_verify_template(chat_template, true)) {
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
invalid_param = true;
return true;
}
@@ -1523,6 +1527,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.use_jinja = true;
return true;
}
if (arg == "--chat-template-kwargs") {
CHECK_ARG
std::string value = argv[i];
auto parsed = json::parse(value);
for (const auto& item : parsed.items()) {
params.default_template_kwargs[item.key()] = item.value().dump();
}
return true;
}
if (arg == "--reasoning-format") {
CHECK_ARG
std::string value = argv[i];
params.reasoning_format = common_reasoning_format_from_name(value);
return true;
}
if (arg == "--no-prefill-assistant") {
CHECK_ARG
params.prefill_assistant = false;
return true;
}
if (arg == "--slot-prompt-similarity" || arg == "-sps") {
CHECK_ARG
params.slot_prompt_similarity = std::stof(argv[i]);
@@ -1831,11 +1855,22 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
"negative prompt file to use for guidance" });
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
options.push_back({ "main", " --jinja",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted:\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
"use jinja template for chat (default: disabled)\n" });
options.push_back({ "main", " --reasoning-format FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
"(default: none)", });
options.push_back({ "main", " --chat-template-kwargs JSON", "sets additional params for the json template parser"});
options.push_back({ "main", " --reasoning-budget N", "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)" });
options.push_back({ "main", " --no-prefill-assistant", "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
"when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" });
options.push_back({ "grammar" });
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
@@ -2095,42 +2130,66 @@ std::string string_format(const char* fmt, ...) {
return std::string(buf.data(), size);
}
std::string regex_escape(const std::string& s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
}
std::vector<std::string> string_split(std::string input, char separator) {
std::vector<std::string> parts;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(0, separator_pos);
parts.emplace_back(part);
input = input.substr(separator_pos + 1);
separator_pos = input.find(separator);
std::string string_join(const std::vector<std::string>& values, const std::string& separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
result << separator;
}
result << values[i];
}
parts.emplace_back(input);
return result.str();
}
std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
std::vector<std::string> parts;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
parts.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
parts.push_back(str.substr(start));
return parts;
}
std::string string_join(const std::vector<std::string> & strs, const std::string & delimiter) {
if (strs.empty()) {
return "";
std::vector<std::string> string_split(const std::string& str, char delim) {
std::vector<std::string> values;
std::istringstream str_stream(str);
std::string token;
while (std::getline(str_stream, token, delim)) {
std::string value;
std::istringstream token_stream(token);
token_stream >> value;
values.push_back(value);
}
return values;
}
std::ostringstream oss;
for (size_t i = 0; i < strs.size(); ++i) {
if (i > 0) {
oss << delimiter;
}
oss << strs[i];
}
return oss.str();
static bool is_utf8_whitespace(uint8_t c) {
// Basic ASCII whitespace
if (c <= 0x7F) return isspace(c);
// Else: Not whitespace (or you'd need a full Unicode table)
return false;
}
std::string string_strip(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && std::isspace(str[start])) {
while (start < end && is_utf8_whitespace(str[start])) {
start++;
}
while (end > start && std::isspace(str[end - 1])) {
while (end > start && is_utf8_whitespace(str[end - 1])) {
end--;
}
return str.substr(start, end - start);
@@ -2163,6 +2222,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
}
}
bool string_ends_with(const std::string_view& str, const std::string_view& suffix) {
return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
size_t string_find_partial_stop(const std::string_view& str, const std::string_view& stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
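// Illustrative note (not part of this diff): string_find_partial_stop returns the
// index where a partial occurrence of `stop` begins at the end of `str`, or npos.
// Assuming the helpers behave as implemented above:
//   string_find_partial_stop("Hello, wor", "world") == 7                  // "wor" partially matches "world"
//   string_find_partial_stop("Hello",      "world") == std::string::npos  // no suffix of str is a prefix of stop
// Streaming code can use this to hold back a suffix that might still grow into a stop string.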
void string_process_escapes(std::string & input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
@@ -3140,154 +3218,172 @@ bool llama_should_add_bos_token(const llama_model * model) {
//
// Chat template utils
//
//
//bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
// if (use_jinja) {
// try {
// auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
// common_chat_inputs inputs;
// inputs.messages = json::array({ {
// {"role", "user"},
// {"content", "test"},
// } });
// common_chat_params_init(chat_template, inputs);
// return true;
// }
// catch (const std::exception& e) {
// fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
// return false;
// }
// }
// llama_chat_message chat[] = { {"user", "test"} };
// const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
// return res >= 0;
//}
bool llama_chat_verify_template(const struct llama_model* model, const std::string& tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({ {
{"role", "user"},
{"content", "test"},
} }, json(), true);
return true;
}
catch (const std::exception& e) {
fprintf(stdout,"%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
//std::string llama_chat_apply_template(const struct llama_model * model,
// const common_chat_template& tmpl,
// const std::vector<common_chat_msg> & msgs,
// bool add_ass,
// bool use_jinja) {
// if (use_jinja) {
// auto messages = json::array();
// for (const auto& msg : msgs) {
// messages.push_back({ {"role", msg.role}, {"content", msg.content} });
// }
// common_chat_inputs inputs;
// inputs.messages = messages;
// inputs.add_generation_prompt = add_ass;
// return common_chat_params_init(tmpl, inputs).prompt;
// }
// int alloc_size = 0;
// std::vector<llama_chat_message> chat;
// for (auto & msg : msgs) {
// chat.push_back({msg.role.c_str(), msg.content.c_str()});
// alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
// }
//
// std::vector<char> buf(alloc_size);
//
// // run the first time to get the total output length
// int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// // error: chat template is not supported
// if (res < 0) {
// // if the custom "tmpl" is not supported, we throw an error
// // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
// throw std::runtime_error("this custom template is not supported");
// }
//
// // if it turns out that our buffer is too small, we resize it
// if ((size_t)res > buf.size()) {
// buf.resize(res);
// res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// }
//
// std::string formatted_chat(buf.data(), res);
// return formatted_chat;
//}
////
//std::string llama_chat_format_single(const struct llama_model * model,
// const common_chat_template& tmpl,
// const std::vector<common_chat_msg> & past_msg,
// const common_chat_msg & new_msg,
// bool add_ass,
// bool use_jinja) {
// std::ostringstream ss;
// auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
// std::vector<common_chat_msg> chat_new(past_msg);
// // if the past_msg ends with a newline, we must preserve it in the formatted version
// if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
// ss << "\n";
// };
// // format chat with new_msg
// chat_new.push_back(new_msg);
// auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
// // get the diff part
// ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
// return ss.str();
//}
std::string llama_chat_apply_template(const struct llama_model * model,
const common_chat_template& tmpl,
const std::vector<llama_chat_msg> & msgs,
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto& msg : msgs) {
messages.push_back({ {"role", msg.role}, {"content", msg.content} });
}
return tmpl.apply(messages, /* tools= */ json(), add_ass);
}
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string llama_chat_format_single(const struct llama_model * model,
const common_chat_template& tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja);
std::vector<llama_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
std::vector<llama_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
}
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
{
auto vocab = llama_model_get_vocab(model);
std::string default_template_src = chat_template_override;
std::string template_tool_use_src = chat_template_override;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
}
else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
}
}
const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
}
else {
return llama_token_to_piece(model, token, true);
}
};
auto token_bos = get_token(llama_token_bos(model), "BOS", "bos_token");
auto token_eos = get_token(llama_token_eos(model), "EOS", "eos_token");
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
};
}
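// Editor's note: an illustrative sketch (not part of this commit) of how the helpers above
// fit together once a model is loaded; `model` is assumed to be a valid handle.
static void demo_chat_formatting(const struct llama_model * model) {
    common_chat_templates templates = llama_chat_templates_from_model(model, /* chat_template_override = */ "");
    std::vector<llama_chat_msg> msgs = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    // Render the full prompt with the default template and append the assistant generation prompt.
    std::string prompt = llama_chat_apply_template(model, *templates.template_default, msgs,
                                                   /* add_ass = */ true, /* use_jinja = */ true);
    fprintf(stdout, "%s\n", prompt.c_str());
}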
//std::string llama_chat_format_example(const struct llama_model * model, const common_chat_template& tmpl, bool use_jinja) {
// std::vector<common_chat_msg> msgs = {
// {"system", "You are a helpful assistant", {}},
// {"user", "Hello", {}},
// {"assistant", "Hi there", {}},
// {"user", "How are you?", {}},
// };
// return llama_chat_apply_template(model, tmpl, msgs, true, use_jinja);
//}
//
//#define CHATML_TEMPLATE_SRC \
// "{%- for message in messages -%}\n" \
// " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
// "{%- endfor -%}\n" \
// "{%- if add_generation_prompt -%}\n" \
// " {{- '<|im_start|>assistant\n' -}}\n" \
// "{%- endif -%}"
//
//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override)
//{
// std::string default_template_src;
// std::string template_tool_use_src;
// bool has_explicit_template = !chat_template_override.empty();
// if (chat_template_override.empty()) {
// auto str = llama_model_chat_template(model, /* name */ nullptr);
// if (str) {
// default_template_src = str;
// has_explicit_template = true;
// }
// str = llama_model_chat_template(model, /* name */ "tool_use");
// if (str) {
// template_tool_use_src = str;
// has_explicit_template = true;
// }
// }
// else {
// default_template_src = chat_template_override;
// }
// if (default_template_src.empty() || default_template_src == "chatml") {
// if (!template_tool_use_src.empty()) {
// default_template_src = template_tool_use_src;
// }
// else {
// default_template_src = CHATML_TEMPLATE_SRC;
// }
// }
// auto vocab = llama_model_get_vocab(model);
// const auto get_token = [&](llama_token token, const char* name, const char* jinja_variable_name) {
// if (token == LLAMA_TOKEN_NULL) {
// if (default_template_src.find(jinja_variable_name) != std::string::npos
// || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
// fprintf(stdout, "%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
// }
// return std::string();
// }
// else {
// return llama_token_to_piece(model, token, true);
// }
// };
// auto token_bos = get_token(llama_token_bos_impl(*vocab), "BOS", "bos_token");
// auto token_eos = get_token(llama_token_eos_impl(*vocab), "EOS", "eos_token");
// try {
// return {
// has_explicit_template,
// std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
// template_tool_use_src.empty()
// ? nullptr
// : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
// };
// }
// catch (const std::exception& e) {
// LOG("%s: failed to parse chat template: %s\n", __func__, e.what());
// return {
// has_explicit_template,
// std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
// nullptr,
// };
// }
//}
//
// KV cache utils
@@ -3778,27 +3874,3 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}
// Additional string utilities for builder pattern compatibility
bool string_starts_with(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) == 0;
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
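// Editor's note: a minimal sketch (not part of this commit) of how a streaming loop can use
// string_find_partial_stop() to withhold text that might be the start of a stop string.
static std::string visible_text(const std::string & generated, const std::string & stop) {
    const size_t pos = string_find_partial_stop(generated, stop);
    if (pos == std::string::npos) {
        return generated;             // the tail cannot be the beginning of `stop`
    }
    return generated.substr(0, pos);  // hold back the partial match until more tokens arrive
}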

View File

@@ -15,14 +15,18 @@
#define LOG_NO_FILE_LINE_FUNCTION
#include "log.h"
#include <set>
#include <cmath>
#include <string>
#include <sstream>
#include <string_view>
#include <vector>
#include <random>
#include <thread>
#include <unordered_map>
#include <tuple>
#include <map>
#include <sstream>
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@@ -74,6 +78,14 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};
// reasoning API response format (not to be confused as chat template's reasoning format)
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
@@ -240,13 +252,21 @@ struct gpt_params {
bool use_jinja = false; // NOLINT
std::string system_prompt = "";
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
int reasoning_budget = -1;
bool prefill_assistant = true;
std::vector<std::string> api_keys;
std::string ssl_file_key = "";
std::string ssl_file_cert = "";
bool endpoint_slots = true;
std::map<std::string, std::string> default_template_kwargs;
// "advanced" endpoints are disabled by default for better security
bool webui = true;
bool endpoint_slots = false;
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
bool log_json = false;
@@ -314,19 +334,24 @@ std::string gpt_params_get_system_info(const gpt_params & params);
//
// String utils
//
std::vector<std::string> string_split(std::string input, char separator);
std::string string_join(const std::vector<std::string> & strs, const std::string & delimiter);
std::string string_join(const std::vector<std::string>& values, const std::string& separator);
std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
static bool string_starts_with(const std::string& str,
const std::string& prefix) { // While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
}
// Additional string utilities for builder pattern compatibility
bool string_starts_with(const std::string & str, const std::string & prefix);
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
std::vector<std::string> string_split(const std::string& str, const std::string& delimiter);
std::vector<std::string> string_split(const std::string& str, char delim);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view& str, const std::string_view& suffix);
size_t string_find_partial_stop(const std::string_view& str, const std::string_view& stop);
std::string regex_escape(const std::string& s);
template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
@@ -342,6 +367,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
return values;
}
template<>
std::vector<std::string> string_split<std::string>(const std::string& input, char separator)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
return parts;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -432,52 +473,59 @@ bool llama_should_add_bos_token(const llama_model * model);
//
// Chat template utils
//
//struct common_tool_call {
// std::string name;
// std::string arguments;
// std::string id;
//};
//
//// same with llama_chat_message, but uses std::string
//struct common_chat_msg {
// std::string role;
// std::string content;
// std::vector<common_tool_call> tool_calls;
// std::string reasoning_content = "";
//};
// same with llama_chat_message, but uses std::string
struct llama_chat_msg {
std::string role;
std::string content;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const struct llama_model* , const std::string& tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had a built-in template or a template override was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string llama_chat_apply_template(
const struct llama_model* model,
const common_chat_template& tmpl,
const std::vector< llama_chat_msg>& chat,
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string llama_chat_format_single(const struct llama_model* model,
const common_chat_template& tmpl,
const std::vector< llama_chat_msg>& past_msg,
const llama_chat_msg& new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string llama_chat_format_example(const struct llama_model* model,
const common_chat_template& tmpl, bool use_jinja);
common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override);
//// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
//bool llama_chat_verify_template(const struct llama_model* , const std::string& tmpl, bool use_jinja);
//
//namespace minja {
// class chat_template;
//}
//
//typedef minja::chat_template common_chat_template;
//
//struct common_chat_templates {
//    bool has_explicit_template; // Model had a built-in template or a template override was specified.
// std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
// std::unique_ptr<common_chat_template> template_tool_use;
//};
//
//
//// CPP wrapper for llama_chat_apply_template
//// If the built-in template is not supported, we default to chatml
//// If the custom "tmpl" is not supported, we throw an error
//std::string llama_chat_apply_template(
// const struct llama_model* model,
// const common_chat_template& tmpl,
// const std::vector< common_chat_msg>& chat,
// bool add_ass,
// bool use_jinja);
//
//// Format single message, while taking into account the position of that message in chat history
//std::string llama_chat_format_single(const struct llama_model* model,
// const common_chat_template& tmpl,
// const std::vector< common_chat_msg>& past_msg,
// const common_chat_msg& new_msg,
// bool add_ass,
// bool use_jinja);
//
//// Returns an example of formatted chat
//std::string llama_chat_format_example(const struct llama_model* model,
// const common_chat_template& tmpl, bool use_jinja);
//
//common_chat_templates llama_chat_templates_from_model(const struct llama_model* model, const std::string& chat_template_override);
//

View File

@@ -383,10 +383,13 @@ namespace grammar_parser {
}
}
}
state.success = true;
return state;
} catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
return parse_state();
fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
parse_state state;
state.success = false;
return state;
}
}
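// Editor's note: an illustrative use of the new `success` flag (not part of this commit).
static bool demo_grammar_parse(const char * src) {
    grammar_parser::parse_state state = grammar_parser::parse(src);
    // On a parse error, parse() has already printed the message and the grammar source to stderr.
    return state.success && !state.rules.empty();
}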

View File

@@ -22,6 +22,7 @@ namespace grammar_parser {
std::vector<std::vector<llama_grammar_element>> rules;
std::vector<const llama_grammar_element *> c_rules();
bool success;
};
parse_state parse(const char * src);

View File

@@ -1,13 +1,10 @@
#include "json-partial.h"
#include <json-partial.h>
#include "ggml.h"
#include "log.h"
#include "../ggml/include/ggml.h"
#include "../examples/server/utils.hpp"
#include "json.hpp"
#include <string>
#include <json.hpp>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
@@ -129,7 +126,7 @@ bool common_json_parse(
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_VERBOSE("Failed to parse up to error", {{"error", ex.what()}, {"content", std::string(it, temptative_end)}});
LOG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {

View File

@@ -1,6 +1,5 @@
#pragma once
#include "json.hpp"
#include <json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {

View File

@@ -267,7 +267,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
throw std::runtime_error("At least one of min_value or max_value must be set");
}
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
struct BuiltinRule {
std::string content;
@@ -389,6 +389,7 @@ static std::string format_literal(const std::string & literal) {
class SchemaConverter {
private:
friend std::string build_grammar(const std::function<void(const common_grammar_builder&)>& cb, const common_grammar_options& options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
@@ -1035,11 +1036,35 @@ public:
}
};
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
return "%llguidance {}\nstart: %json " + schema.dump();
}
#else
(void)force_gbnf;
#endif // LLAMA_USE_LLGUIDANCE
return build_grammar([&](const common_grammar_builder& callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
}
std::string build_grammar(const std::function<void(const common_grammar_builder&)>& cb, const common_grammar_options& options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder{
/* .add_rule = */ [&](const std::string& name, const std::string& rule) {
return converter._add_rule(name, rule);
},
/* .add_schema = */ [&](const std::string& name, const nlohmann::ordered_json& schema) {
return converter.visit(schema, name == "root" ? "" : name);
},
/* .resolve_refs = */ [&](nlohmann::ordered_json& schema) {
converter.resolve_refs(schema, "");
}
};
cb(builder);
converter.check_errors();
return converter.format_grammar();
}
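// Editor's note: an illustrative sketch (not part of this commit) that converts a small
// JSON schema into a grammar string with the helper above.
static std::string demo_schema_to_grammar() {
    json schema = {
        {"type", "object"},
        {"properties", {
            {"name", {{"type", "string"}}},
            {"age",  {{"type", "integer"}}},
        }},
        {"required", json::array({"name"})},
    };
    // force_gbnf keeps plain GBNF output even when llguidance support is compiled in.
    return json_schema_to_grammar(schema, /* force_gbnf = */ true);
}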

View File

@@ -5,4 +5,17 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
struct common_grammar_builder {
std::function<std::string(const std::string&, const std::string&)> add_rule;
std::function<std::string(const std::string&, const nlohmann::ordered_json&)> add_schema;
std::function<void(nlohmann::ordered_json&)> resolve_refs;
};
struct common_grammar_options {
bool dotall = false;
};
std::string build_grammar(const std::function<void(const common_grammar_builder&)>& cb, const common_grammar_options& options = {});
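// Editor's note: a minimal sketch (not part of this commit) of the builder callback API
// declared above; the rule bodies use ordinary GBNF syntax.
static inline std::string demo_build_grammar() {
    return build_grammar([](const common_grammar_builder & builder) {
        // add_rule returns the registered (possibly uniquified) rule name.
        const std::string item = builder.add_rule("item", "\"a\" | \"b\"");
        builder.add_rule("root", item + " (\",\" " + item + ")*");
    });
}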

common/llguidance.cpp (new file, 270 lines)
View File

@@ -0,0 +1,270 @@
#include "sampling.h"
#include "log.h"
#ifdef LLAMA_USE_LLGUIDANCE
# include "llguidance.h"
# include <cmath>
struct llama_sampler_llg {
const llama_vocab * vocab;
std::string grammar_kind;
std::string grammar_data;
LlgTokenizer * tokenizer;
LlgConstraint * grammar;
LlgMaskResult llg_res;
bool has_llg_res;
};
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
if (log_level && *log_level) {
cinit.log_stderr_level = atoi(log_level);
}
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
if (llg_get_error(c)) {
LOG_ERR("llg error: %s\n", llg_get_error(c));
llg_free_constraint(c);
return nullptr;
}
return c;
}
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
return "llguidance";
}
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
LlgCommitResult res;
llg_commit_token(ctx->grammar, token, &res);
ctx->has_llg_res = false;
}
}
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
if (!ctx->has_llg_res) {
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
ctx->has_llg_res = true;
} else {
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
llg_free_constraint(ctx->grammar);
ctx->grammar = nullptr;
}
}
if (ctx->has_llg_res) {
if (ctx->llg_res.is_stop) {
for (size_t i = 0; i < cur_p->size; ++i) {
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
}
}
} else {
const uint32_t * mask = ctx->llg_res.sample_mask;
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
cur_p->data[i].logit = -INFINITY;
}
}
}
}
}
}
static void llama_sampler_llg_reset(llama_sampler * smpl) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (!ctx->grammar) {
return;
}
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
llg_free_constraint(ctx->grammar);
ctx->grammar = grammar_new;
ctx->has_llg_res = false;
}
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
// copy the state
{
auto * result_ctx = (llama_sampler_llg *) result->ctx;
if (ctx->grammar) {
result_ctx->grammar_kind = ctx->grammar_kind;
result_ctx->grammar_data = ctx->grammar_data;
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
}
}
return result;
}
static void llama_sampler_llg_free(llama_sampler * smpl) {
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
llg_free_constraint(ctx->grammar);
llg_free_tokenizer(ctx->tokenizer);
}
delete ctx;
}
static llama_sampler_i llama_sampler_llg_i = {
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
};
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
uint32_t * output_tokens, size_t output_tokens_len) {
const llama_vocab * vocab = (const llama_vocab *) user_data;
int r = 0;
try {
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
true);
} catch (const std::exception & e) {
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
}
if (r < 0) {
return -r;
}
return r;
}
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
// TODO store the tokenizer in the vocab somehow
static const llama_vocab * vocab_cache;
static LlgTokenizer * tokenizer_cache;
if (vocab_cache == vocab) {
return llg_clone_tokenizer(tokenizer_cache);
}
auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL) {
tok_eos = llama_vocab_eos(vocab);
}
size_t vocab_size = llama_vocab_n_tokens(vocab);
auto token_lens = new uint32_t[vocab_size];
// we typically have ~7 bytes per token; let's go on the safe side here
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
auto token_bytes = new uint8_t[token_bytes_size];
size_t offset = 0;
for (size_t i = 0; i < vocab_size; i++) {
size_t max_token = 1024;
if (token_bytes_size - offset < max_token) {
GGML_ABORT("token_bytes buffer too small\n");
}
llama_token token = i;
auto dp = (char *) token_bytes + offset;
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size == 0) {
size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size != 0) {
*dp = '\xff'; // special token prefix marker
size += 1;
}
}
token_lens[i] = size;
offset += size;
}
LlgTokenizerInit tinit = {
/* .vocab_size = */ (uint32_t) vocab_size,
/* .tok_eos = */ (uint32_t) tok_eos,
/* .token_lens = */ token_lens,
/* .token_bytes = */ token_bytes,
/* .tokenizer_json = */ nullptr,
/* .tokenize_assumes_string = */ true,
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
/* .use_approximate_greedy_tokenize_fn = */ false,
/* .tokenize_user_data = */ vocab,
};
char error_buffer[1024];
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
delete[] token_bytes;
delete[] token_lens;
if (tokenizer == nullptr) {
LOG_ERR("llg tokenizer error: %s\n", error_buffer);
return tokenizer;
}
if (tokenizer_cache) {
llg_free_tokenizer(tokenizer_cache);
}
vocab_cache = vocab;
tokenizer_cache = tokenizer;
return llg_clone_tokenizer(tokenizer_cache);
}
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
const char * grammar_data) {
auto * ctx = new llama_sampler_llg;
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
*ctx = {
/* .vocab = */ vocab,
/* .grammar_kind = */ grammar_kind,
/* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer,
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
} else {
*ctx = {
/* .vocab = */ vocab,
/* .grammar_kind = */ {},
/* .grammar_data = */ {},
/* .tokenizer = */ nullptr,
/* .grammar = */ nullptr,
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
}
return new llama_sampler{
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
}
#else
llama_grammar * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
LOG("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
return nullptr;
}
#endif // LLAMA_USE_LLGUIDANCE
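// Editor's note: an illustrative sketch (not part of this commit). The "lark" grammar kind
// mirrors its use in sampling.cpp; the call only yields a usable constraint when the project
// is built with -DLLAMA_LLGUIDANCE=ON.
static bool demo_llg_constraint(const llama_vocab * vocab) {
    const char * grammar = "start: \"yes\" | \"no\"\n";   // a tiny lark grammar
    auto * smpl = llama_sampler_init_llg(vocab, "lark", grammar);
    return smpl != nullptr;   // nullptr when llguidance support is compiled out
}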

View File

@@ -1,5 +1,4 @@
#pragma once
#include <chrono>
#include <cstring>
#include <sstream>

View File

@@ -0,0 +1,549 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "minja.hpp"
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <ctime>
#include <exception>
#include <iomanip>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <json.hpp>
using json = nlohmann::ordered_json;
namespace minja {
struct chat_template_caps {
bool supports_tools = false;
bool supports_tool_calls = false;
bool supports_tool_responses = false;
bool supports_system_role = false;
bool supports_parallel_tool_calls = false;
bool supports_tool_call_id = false;
// meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments = false;
// CohereForAI/c4ai-command-r-plus simple variant
bool requires_non_null_content = false;
// MiniMaxAI/MiniMax-Text-01 special
bool requires_typed_content = false;
};
struct chat_template_inputs {
nlohmann::ordered_json messages;
nlohmann::ordered_json tools;
bool add_generation_prompt = true;
nlohmann::ordered_json extra_context;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
struct chat_template_options {
bool apply_polyfills = true;
bool use_bos_token = true;
bool use_eos_token = true;
bool define_strftime_now = true;
bool polyfill_tools = true;
bool polyfill_tool_call_examples = true;
bool polyfill_tool_calls = true;
bool polyfill_tool_responses = true;
bool polyfill_system_role = true;
bool polyfill_object_arguments = true;
bool polyfill_typed_content = true;
};
class chat_template {
private:
chat_template_caps caps_;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;
std::string tool_call_example_;
std::string try_raw_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
chat_template_inputs inputs;
inputs.messages = messages;
inputs.tools = tools;
inputs.add_generation_prompt = add_generation_prompt;
inputs.extra_context = extra_context;
// Use fixed date for tests
inputs.now = std::chrono::system_clock::from_time_t(0);
chat_template_options opts;
opts.apply_polyfills = false;
auto prompt = apply(inputs, opts);
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
// fprintf(stderr, "try_raw_render error: %s\n", e.what());
return "";
}
}
public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
{
template_root_ = minja::Parser::parse(source_, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
auto contains = [](const std::string & haystack, const std::string & needle) {
return haystack.find(needle) != std::string::npos;
};
const std::string user_needle = "<User Needle>";
const std::string sys_needle = "<System Needle>";
const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
caps_.requires_typed_content =
!contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
&& contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
const auto dummy_user_msg = caps_.requires_typed_content
? dummy_typed_user_msg
: dummy_str_user_msg;
const json needle_system_msg = {
{"role", "system"},
{"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
};
caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
auto out = try_raw_render(json::array({
dummy_user_msg
}), json::array({
{
{"name", "some_tool"},
{"type", "function"},
{"function", {
{"name", "some_tool"},
{"description", "Some tool."},
{"parameters", {
{"type", "object"},
{"properties", {
{"arg", {
{"type", "string"},
{"description", "Some argument."},
}},
}},
{"required", json::array({ "arg" })},
}},
}},
},
}), false);
caps_.supports_tools = contains(out, "some_tool");
const auto render_with_content = [&](const json & content) {
const json assistant_msg {{"role", "assistant"}, {"content", content}};
// Render two assistant messages, since some templates like QwQ-32B handle
// the content differently depending on whether it's the last message or not
// (to remove the <think> tag in all but the last message).
return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false);
};
auto out_empty = render_with_content("");
auto out_null = render_with_content(json());
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
json j_null;
auto make_tool_calls_msg = [&](const json & tool_calls) {
return json {
{"role", "assistant"},
{"content", caps_.requires_non_null_content? "" : j_null},
{"tool_calls", tool_calls},
};
};
auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
return json {
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", arguments},
{"name", tool_name},
}},
};
};
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
}), {}, false);
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
}), {}, false);
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
if (caps_.supports_tool_calls) {
auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
auto tc1 = make_tool_call("test_tool1", dummy_args);
auto tc2 = make_tool_call("test_tool2", dummy_args);
auto out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({tc1, tc2})),
}), {}, false);
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({tc1})),
{
{"role", "tool"},
{"name", "test_tool1"},
{"content", "Some response!"},
{"tool_call_id", "call_911_"},
}
}), {}, false);
caps_.supports_tool_responses = contains(out, "Some response!");
caps_.supports_tool_call_id = contains(out, "call_911_");
}
try {
if (!caps_.supports_tools) {
const json user_msg {
{"role", "user"},
{"content", "Hey"},
};
const json args {
{"arg1", "some_value"},
};
const json tool_call_msg {
{"role", "assistant"},
{"content", caps_.requires_non_null_content ? "" : j_null},
{"tool_calls", json::array({
{
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"name", "tool_name"},
{"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
}},
},
})},
};
std::string prefix, full;
{
chat_template_inputs inputs;
inputs.messages = json::array({user_msg});
inputs.add_generation_prompt = true;
prefix = apply(inputs);
}
{
chat_template_inputs inputs;
inputs.messages = json::array({user_msg, tool_call_msg});
inputs.add_generation_prompt = false;
full = apply(inputs);
}
auto eos_pos_last = full.rfind(eos_token_);
if (eos_pos_last == prefix.size() - eos_token_.size() ||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
full = full.substr(0, eos_pos_last);
}
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
if (prefix[i] != full[i]) {
break;
}
if (prefix[i] == '<') {
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
// but it removes thinking tags for past messages.
// The prefix and full strings diverge at <think> vs. <tool▁calls▁begin>, so we avoid consuming the leading <.
continue;
}
common_prefix_length = i + 1;
}
auto example = full.substr(common_prefix_length);
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
} else {
tool_call_example_ = example;
}
}
} catch (const std::exception & e) {
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
}
}
const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
const chat_template_caps & original_caps() const { return caps_; }
// Deprecated, please use the form with chat_template_inputs and chat_template_options
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
bool apply_polyfills = true)
{
fprintf(stderr, "[%s] Deprecated!\n", __func__);
chat_template_inputs inputs;
inputs.messages = messages;
inputs.tools = tools;
inputs.add_generation_prompt = add_generation_prompt;
inputs.extra_context = extra_context;
inputs.now = std::chrono::system_clock::now();
chat_template_options opts;
opts.apply_polyfills = apply_polyfills;
return apply(inputs, opts);
}
std::string apply(
const chat_template_inputs & inputs,
const chat_template_options & opts = chat_template_options()) const
{
json actual_messages;
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_tool_calls = false;
auto has_tool_responses = false;
auto has_string_content = false;
for (const auto & message : inputs.messages) {
if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
has_tool_calls = true;
}
if (message.contains("role") && message["role"] == "tool") {
has_tool_responses = true;
}
if (message.contains("content") && message["content"].is_string()) {
has_string_content = true;
}
}
auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
auto needs_polyfills = opts.apply_polyfills && (false
|| polyfill_system_role
|| polyfill_tools
|| polyfill_tool_calls
|| polyfill_tool_responses
|| polyfill_object_arguments
|| polyfill_typed_content
);
if (needs_polyfills) {
actual_messages = json::array();
auto add_message = [&](const json & msg) {
if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
actual_messages.push_back({
{"role", msg.at("role")},
{"content", {{
{"type", "text"},
{"text", msg.at("content")},
}}},
});
} else {
actual_messages.push_back(msg);
}
};
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
add_message({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
json adjusted_messages;
if (polyfill_tools) {
adjusted_messages = add_system(inputs.messages,
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
} else {
adjusted_messages = inputs.messages;
}
for (const auto & message_ : adjusted_messages) {
auto message = message_;
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
}
std::string role = message.at("role");
if (message.contains("tool_calls")) {
if (polyfill_object_arguments || polyfill_tool_calls) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
auto & arguments = function.at("arguments");
if (arguments.is_string()) {
try {
arguments = json::parse(arguments.get<std::string>());
} catch (const std::exception & ecvt) {
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
}
}
}
}
}
if (polyfill_tool_calls) {
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
if (tool_call.at("type") != "function") {
continue;
}
const auto & function = tool_call.at("function");
auto tc = json {
{"name", function.at("name")},
{"arguments", function.at("arguments")},
};
if (tool_call.contains("id")) {
tc["id"] = tool_call["id"];
}
tool_calls.push_back(tc);
}
auto obj = json {
{"tool_calls", tool_calls},
};
if (message.contains("content")) {
auto content = message.at("content");
if (!content.is_null() && !content.empty()) {
obj["content"] = content;
}
}
message["content"] = obj.dump(2);
message.erase("tool_calls");
}
}
if (polyfill_tool_responses && role == "tool") {
message["role"] = "user";
auto obj = json {
{"tool_response", json::object()},
};
if (message.contains("name")) {
obj["tool_response"]["tool"] = message.at("name");
}
obj["tool_response"]["content"] = message.at("content");
if (message.contains("tool_call_id")) {
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
}
message["content"] = obj.dump(2);
message.erase("name");
}
if (!message["content"].is_null() && polyfill_system_role) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
add_message(message);
}
flush_sys();
} else {
actual_messages = inputs.messages;
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", inputs.add_generation_prompt},
}));
context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
if (opts.define_strftime_now) {
auto now = inputs.now;
context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
args.expectArgs("strftime_now", {1, 1}, {0, 0});
auto format = args.args[0].get<std::string>();
auto time = std::chrono::system_clock::to_time_t(now);
auto local_time = *std::localtime(&time);
std::ostringstream ss;
ss << std::put_time(&local_time, format.c_str());
return ss.str();
}));
}
if (!inputs.tools.is_null()) {
context->set("tools", minja::Value(inputs.tools));
}
if (!inputs.extra_context.is_null()) {
for (auto & kv : inputs.extra_context.items()) {
context->set(kv.key(), minja::Value(kv.value()));
}
}
auto ret = template_root_->render(context);
// fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
// fprintf(stderr, "apply: %s\n\n", ret.c_str());
return ret;
}
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
json messages_with_system = messages;
if (!messages_with_system.empty() && messages_with_system[0].at("role") == "system") {
std::string existing_system = messages_with_system.at(0).at("content");
messages_with_system[0] = json {
{"role", "system"},
{"content", existing_system + "\n\n" + system_prompt},
};
} else {
messages_with_system.insert(messages_with_system.begin(), json {
{"role", "system"},
{"content", system_prompt},
});
}
return messages_with_system;
}
};
} // namespace minja
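// Editor's note: a sketch (not part of this commit) of the non-deprecated apply() overload above.
static std::string demo_render(const minja::chat_template & tmpl) {
    minja::chat_template_inputs inputs;
    inputs.messages = json::array({
        {{"role", "user"}, {"content", "Hello"}},
    });
    inputs.add_generation_prompt = true;
    minja::chat_template_options opts;   // defaults: apply polyfills, expose BOS/EOS tokens to the template
    return tmpl.apply(inputs, opts);
}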

View File

@@ -8,14 +8,27 @@
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cmath>
#include <exception>
#include <functional>
#include <iostream>
#include <string>
#include <vector>
#include <regex>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <stdexcept>
#include <regex>
#include <sstream>
#include <string>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <json.hpp>
using json = nlohmann::ordered_json;
@@ -1233,7 +1246,7 @@ public:
}
return result;
} else if (target_value.is_array()) {
} else if (target_value.is_array()) {
auto result = Value::array();
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
result.push_back(target_value.at(i));
@@ -1278,6 +1291,12 @@ public:
}
};
static bool in(const Value & value, const Value & container) {
return (((container.is_array() || container.is_object()) && container.contains(value)) ||
(value.is_string() && container.is_string() &&
container.to_str().find(value.to_str()) != std::string::npos));
}
class BinaryOpExpr : public Expression {
public:
enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
@@ -1342,13 +1361,8 @@ public:
case Op::Gt: return l > r;
case Op::Le: return l <= r;
case Op::Ge: return l >= r;
case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) ||
(l.is_string() && r.is_string() &&
r.to_str().find(l.to_str()) != std::string::npos));
case Op::NotIn:
return !(((r.is_array() || r.is_object()) && r.contains(l)) ||
(l.is_string() && r.is_string() &&
r.to_str().find(l.to_str()) != std::string::npos));
case Op::In: return in(l, r);
case Op::NotIn: return !in(l, r);
default: break;
}
throw std::runtime_error("Unknown binary operator");
@@ -1487,6 +1501,13 @@ public:
} else if (method->get_name() == "pop") {
vargs.expectArgs("pop method", {1, 1}, {0, 0});
return obj.pop(vargs.args[0]);
} else if (method->get_name() == "keys") {
vargs.expectArgs("keys method", {0, 0}, {0, 0});
auto result = Value::array();
for (const auto& key : obj.keys()) {
result.push_back(Value(key));
}
return result;
} else if (method->get_name() == "get") {
vargs.expectArgs("get method", {1, 2}, {0, 0});
auto key = vargs.args[0];
@@ -1528,6 +1549,16 @@ public:
} else if (method->get_name() == "capitalize") {
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
return Value(capitalize(str));
} else if (method->get_name() == "upper") {
vargs.expectArgs("upper method", {0, 0}, {0, 0});
auto result = str;
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
return Value(result);
} else if (method->get_name() == "lower") {
vargs.expectArgs("lower method", {0, 0}, {0, 0});
auto result = str;
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
return Value(result);
} else if (method->get_name() == "endswith") {
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
auto suffix = vargs.args[0].get<std::string>();
@@ -1544,20 +1575,6 @@ public:
else res[i] = std::tolower(res[i]);
}
return res;
} else if (method->get_name() == "replace") {
vargs.expectArgs("replace method", {2, 3}, {0, 0});
auto before = vargs.args[0].get<std::string>();
auto after = vargs.args[1].get<std::string>();
auto count = vargs.args.size() == 3 ? vargs.args[2].get<int64_t>()
: str.length();
size_t start_pos = 0;
while ((start_pos = str.find(before, start_pos)) != std::string::npos &&
count-- > 0) {
str.replace(start_pos, before.length(), after);
start_pos += after.length();
}
return str;
}
}
throw std::runtime_error("Unknown method: " + method->get_name());
@@ -2117,38 +2134,9 @@ private:
std::shared_ptr<Expression> start, end, step;
bool has_first_colon = false, has_second_colon = false;
if (!peekSymbols({ ":" })) {
start = parseExpression();
}
if (!consumeToken(":").empty()) {
has_first_colon = true;
@@ -2162,8 +2150,8 @@ private:
}
}
}
if ((has_first_colon || has_second_colon)) {
if ((has_first_colon || has_second_colon) && (start || end || step)) {
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
} else {
index = std::move(start);
@@ -2663,17 +2651,13 @@ inline std::shared_ptr<Context> Context::builtins() {
auto items = Value::array();
if (args.contains("object")) {
auto & obj = args.at("object");
if (obj.is_string()) {
auto json_obj = json::parse(obj.get<std::string>());
for (const auto & kv : json_obj.items()) {
items.push_back(Value::array({kv.key(), kv.value()}));
if (!obj.is_object()) {
throw std::runtime_error("Can only get item pairs from a mapping");
}
} else if (!obj.is_null()) {
for (auto & key : obj.keys()) {
items.push_back(Value::array({key, obj.at(key)}));
}
}
}
return items;
}));
globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr<Context> &, Value & args) {
@@ -2686,14 +2670,6 @@ inline std::shared_ptr<Context> Context::builtins() {
auto & text = args.at("text");
return text.is_null() ? text : Value(strip(text.get<std::string>()));
}));
auto char_transform_function = [](const std::string & name, const std::function<char(char)> & fn) {
return simple_function(name, { "text" }, [=](const std::shared_ptr<Context> &, Value & args) {
auto text = args.at("text");
@@ -2807,6 +2783,9 @@ inline std::shared_ptr<Context> Context::builtins() {
if (!items.is_array()) throw std::runtime_error("object is not iterable");
return items;
}));
globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
return in(args.at("item"), args.at("items"));
}));
globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
auto & items = args.at("items");
if (!items.is_array()) throw std::runtime_error("object is not iterable");
@@ -2846,16 +2825,10 @@ inline std::shared_ptr<Context> Context::builtins() {
if (filter_fn.is_null()) {
throw std::runtime_error("Undefined filter: " + args.args[1].dump());
}
auto filter_args = Value::array();
for (size_t i = 2, n = args.args.size(); i < n; i++) {
filter_args.push_back(args.args[i]);
}
auto filter = make_filter(filter_fn, filter_args);
@@ -2942,8 +2915,6 @@ inline std::shared_ptr<Context> Context::builtins() {
}
test_args.kwargs = args.kwargs;
}
auto res = Value::array();
for (size_t i = 0, n = items.size(); i < n; i++) {
@@ -2957,10 +2928,7 @@ inline std::shared_ptr<Context> Context::builtins() {
} else {
res.push_back(attr);
}
}
return res;
});
};
@@ -2978,7 +2946,6 @@ inline std::shared_ptr<Context> Context::builtins() {
auto v = arg.get<int64_t>();
startEndStep[i] = v;
param_set[i] = true;
}
}
for (auto & [name, value] : args.kwargs) {

View File

@@ -118,7 +118,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ',');
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");

View File

@@ -9,8 +9,23 @@ enum common_regex_match_type {
COMMON_REGEX_MATCH_TYPE_FULL,
};
// Include full definition of common_string_range
#include "chat.h"
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
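// Editor's note: a tiny illustration (not part of this commit) of the half-open
// [begin, end) semantics of common_string_range.
static inline std::string slice(const std::string & s, const common_string_range & rng) {
    return s.substr(rng.begin, rng.end - rng.begin);
}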
struct common_regex_match {
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;

View File

@@ -1,7 +1,10 @@
#define LLAMA_API_INTERNAL
#include "sampling.h"
#include "llama-vocab.h"
#include "common.h"
#include <random>
#include "json.hpp"
using json = nlohmann::ordered_json;
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@@ -9,10 +12,68 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
result->params = params;
result->grammar = nullptr;
struct llama_grammar* grmr;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
}
else {
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto& trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto& word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto token = trigger.token;
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char*> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto& regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
// if there is a grammar, parse it
if (!params.grammar.empty()) {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
if (result->parsed_grammar.success) {
// will be empty (default) if there are parse errors
if (result->parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
@@ -26,21 +87,15 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
delete result;
return nullptr;
}
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
struct llama_grammar * grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr) {
if (grmr == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
result->grammar = grammar;
}
}
result->prev.resize(params.n_prev);
result->n_valid = 0;
}
result->grammar = grmr;
// init DRY
for (const auto& cnstr : params.samplers_sequence)
{
@@ -75,27 +130,71 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
delete ctx;
}
void llama_sampling_reset(llama_sampling_context * ctx) {
void llama_sampling_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
ctx->grammar = NULL;
}
if (!ctx->parsed_grammar.rules.empty()) {
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
struct llama_grammar * grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
ctx->grammar = grammar;
struct llama_grammar* grmr;
auto params = ctx->params;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
}
else {
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto& trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto& word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto token = trigger.token;
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_valid = 0;
std::vector<const char*> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto& regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
ctx->grammar = grmr;
llama_sampler_dry_reset(ctx->smpl);
}
@@ -498,7 +597,10 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar) {
if (ctx_sampling->prev.size() > 0) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
}
ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL && apply_grammar) {
@@ -552,3 +654,29 @@ std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_samplin
return result;
}
template <>
json common_grammar_trigger::to_json() const {
json out{
{"type", (int)type},
{"value", value},
};
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int)token;
}
return out;
}
template <>
common_grammar_trigger common_grammar_trigger::from_json(const json& in) {
common_grammar_trigger out;
out.type = (common_grammar_trigger_type)in.at("type").get<int>();
out.value = in.at("value").get<std::string>();
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out.token = (llama_token)in.at("token").get<int>();
}
return out;
}
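The to_json/from_json specializations added above let a grammar trigger survive a round trip through the server's JSON plumbing. A rough standalone illustration of the expected shape follows; the struct is re-declared in simplified form purely for the example, and only nlohmann::ordered_json is assumed (the repo vendors it as json.hpp).
#include <cassert>
#include <string>
#include <nlohmann/json.hpp>   // assumed include path; adjust to "json.hpp" inside the repo
using ordered_json = nlohmann::ordered_json;
// Simplified stand-in for common_grammar_trigger, re-declared here only so the
// example is self-contained.
enum trigger_type { TRIGGER_TOKEN = 0, TRIGGER_WORD = 1 };
struct trigger {
    trigger_type type;
    std::string  value;
    int          token = -1;
};
int main() {
    trigger t{TRIGGER_TOKEN, "<tool_call>", 128010};
    // serialize, following the same shape as common_grammar_trigger::to_json()
    ordered_json out{{"type", (int) t.type}, {"value", t.value}};
    if (t.type == TRIGGER_TOKEN) out["token"] = t.token;
    // parse back, mirroring from_json()
    trigger back;
    back.type  = (trigger_type) out.at("type").get<int>();
    back.value = out.at("value").get<std::string>();
    if (back.type == TRIGGER_TOKEN) back.token = out.at("token").get<int>();
    assert(back.type == t.type && back.value == t.value && back.token == t.token);
    return 0;
}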

View File

@@ -1,9 +1,8 @@
#pragma once
#include "llama.h"
#include "grammar-parser.h"
#include <set>
#include <random>
#include <string>
#include <unordered_map>
@@ -22,6 +21,23 @@ enum class llama_sampler_type : char {
TEMPERATURE = 't'
};
enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
};
struct common_grammar_trigger {
common_grammar_trigger_type type;
std::string value;
llama_token token = LLAMA_TOKEN_NULL;
// T can only be nlohmann::ordered_json
template <class T> T to_json() const;
template <class T> static common_grammar_trigger from_json(const T& in);
};
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@@ -67,8 +83,11 @@ typedef struct llama_sampling_params {
llama_sampler_type::TEMPERATURE
};
    std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // string to help guidance
@@ -106,7 +125,7 @@ struct llama_sampling_context {
std::mt19937 rng;
};
#include "common.h"
// Create a new sampling context instance.
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params);
@@ -116,7 +135,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
// Reset the sampler context
// - clear prev tokens
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);
void llama_sampling_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx);
// Set the sampler seed
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
@@ -186,3 +205,6 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_con
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft);
llama_grammar* llama_sampler_init_llg(const llama_vocab* vocab,
const char* grammar_kind, const char* grammar_data);
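The header now also declares llama_sampler_init_llg; sampling.cpp routes a grammar to it whenever the grammar text starts with "%llguidance" (the compare(0, 11, ...) checks above, available only when built with -DLLAMA_LLGUIDANCE=ON), and otherwise falls back to the regular GBNF path. A standalone sketch of that prefix test, with illustrative inputs:
#include <iostream>
#include <string>
// Grammars beginning with "%llguidance" are dispatched to the llguidance sampler;
// everything else is parsed as a GBNF grammar.
static bool is_llguidance_grammar(const std::string & grammar) {
    static const std::string prefix = "%llguidance";           // 11 characters
    return grammar.compare(0, prefix.size(), prefix) == 0;     // same test as compare(0, 11, ...)
}
int main() {
    std::cout << is_llguidance_grammar("%llguidance {}\nstart: \"yes\" | \"no\"") << "\n"; // 1
    std::cout << is_llguidance_grammar("root ::= \"yes\" | \"no\"")               << "\n"; // 0
}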

View File

@@ -3,7 +3,7 @@
#include "common.h"
#include "sampling.h"
#include "llama-impl.h"
#include "llama-vocab.h"
#include <cstring>
#include <algorithm>
#include <map>
@@ -302,7 +302,7 @@ std::vector<llama_token> llama_speculative_gen_draft(
llama_decode(ctx_dft, batch);
llama_sampling_reset(smpl);
llama_sampling_reset(llama_get_vocab(ctx_dft), smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {