Tool calls support from mainline (#723)

* Tool calls support from mainline

* update cmake

* revert api for /completions

* Fix broken thinking process for gpt-oss

* add missing args and fix webui bugs

* add missing args and fix webui bugs2

* Fix reasoning format error

* add usage

* change default post_sampling_probs to true

* add back generated_text

* Remove server endpoints tests

* add log

* Chat fixes

* Remove logs

* webui: revert extra handling of thinking process

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
firecoperana
2025-09-01 00:38:49 -05:00
committed by GitHub
parent 8de297b795
commit d7882c3cf8
87 changed files with 13581 additions and 2224 deletions

View File

@@ -84,56 +84,60 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
# build test-tokenizer-1-bpe target once and add many tests
add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
install(TARGETS test-tokenizer-1-bpe RUNTIME)
if (LLAMA_LLGUIDANCE)
llama_target_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
endif ()
# TODO: disabled due to slowness
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
if (NOT WIN32)
# these tests are disabled on Windows because they use internal functions not exported with LLAMA_API
#llama_target_and_test(test-sampling.cpp)
#llama_target_and_test(test-grammar-parser.cpp)
#llama_target_and_test(test-grammar-integration.cpp)
#llama_target_and_test(test-llama-grammar.cpp)
#llama_target_and_test(test-chat.cpp)
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
endif()
# build test-tokenizer-1-spm target once and add many tests
add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
target_link_libraries(test-tokenizer-1-spm PRIVATE common)
install(TARGETS test-tokenizer-1-spm RUNTIME)
llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
#llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
# build test-tokenizer-1-bpe target once and add many tests
add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
install(TARGETS test-tokenizer-1-bpe RUNTIME)
# llama_target_and_test(test-double-float.cpp) # SLOW
llama_target_and_test(test-quantize-fns.cpp)
llama_target_and_test(test-quantize-perf.cpp)
llama_target_and_test(test-sampling.cpp)
# TODO: disabled due to slowness
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
# build test-tokenizer-1-spm target once and add many tests
add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
target_link_libraries(test-tokenizer-1-spm PRIVATE common)
install(TARGETS test-tokenizer-1-spm RUNTIME)
llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
#llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
# llama_target_and_test(test-double-float.cpp) # SLOW
endif()
llama_target_and_test(test-chat-parser.cpp)
llama_target_and_test(test-chat-template.cpp)
llama_target_and_test(test-json-partial.cpp)
llama_target_and_test(test-regex-partial.cpp)
llama_target_and_test(test-grammar-parser.cpp)
llama_target_and_test(test-llama-grammar.cpp)
llama_target_and_test(test-grammar-integration.cpp)
llama_target_and_test(test-grad0.cpp)
# llama_target_and_test(test-opt.cpp) # SLOW
llama_target_and_test(test-backend-ops.cpp)
llama_target_and_test(test-rope.cpp)
llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
llama_target_and_test(test-autorelease.cpp LABEL "model")
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
endif()
# Function calling parser tests
llama_target_and_test(test-function-calls.cpp)
target_include_directories(test-function-calls PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
# dummy executable - not installed
get_filename_component(TEST_TARGET test-c.c NAME_WE)

355
tests/test-chat-parser.cpp Normal file
View File

@@ -0,0 +1,355 @@
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <exception>
#include <iostream>
#include <json.hpp>
#include <string>
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
using json = nlohmann::ordered_json;
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void assert_equals(const char * expected, const std::string & actual) {
return assert_equals<std::string>(expected, actual);
}
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
try {
fn();
} catch (const std::exception & e) {
if (expected_exception_pattern.empty()) {
return;
}
std::regex expected_exception_regex(expected_exception_pattern);
std::string actual_message = e.what();
if (std::regex_search(actual_message, expected_exception_regex)) {
return;
}
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
}
throw std::runtime_error("Exception was expected but not thrown");
}
static void test_reasoning() {
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ true,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<think>Cogito</think>", builder.result().content);
assert_equals("Ergo sum", builder.consume_rest());
}
}
static void test_regex() {
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
};
test_throws("Hello, world!", "abc", "^abc$");
test_throws("Hello, world!", "e", "^e$");
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
builder.consume_regex(common_regex("Hello"));
assert_equals(", world!", builder.consume_rest());
}
{
// When in non partial mode, we can say whether the regex was consumed or not.
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
}
{
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
assert_equals(true, res.has_value());
// Verify captures
assert_equals<size_t>(2, res->groups.size());
assert_equals("Hell", builder.str(res->groups[0]));
assert_equals("el", builder.str(res->groups[1]));
// Verify position is after the match
assert_equals<size_t>(4, builder.pos());
assert_equals("o,", builder.consume_rest());
}
{
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
assert_throws([&]() {
builder.try_consume_regex(common_regex("Hello, world!"));
}, "^Hello, world!$");
}
// Now regardless of the mode, we can tell these aren't a match.
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
}
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_literal("Oh"));
}
}
const std::vector<std::string> barely_healable_jsons = {
"{",
"{\"",
"{\"\\",
"{\"n",
"{\"name\"",
"{\"name\":",
"{\"name\":\"",
"{\"name\":\"\\",
"{\"name\":\"python",
"{\"name\":\"python\\",
"{\",",
"{\":",
"{\"[",
"{\"]",
"{\"{",
"{\"}",
"{\"1",
"{\"name\":\",",
"{\"name\":\":",
"{\"name\":\"[",
"{\"name\":\"]",
"{\"name\":\"{",
"{\"name\":\"}",
"{\"name\":\"1",
};
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
common_chat_msg_parser builder(input, is_partial, {});
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
}
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
common_chat_msg_parser builder(input, parse_as_partial, {});
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, js->value.dump());
}
static void test_json_with_dumped_args_no_args() {
// Normal JSON, nothing to heal, nothing to dump
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
// Full json is args
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
// If the arguments are further down, don't heal partial content.
for (const auto & src : barely_healable_jsons) {
test(src, true, {{"arguments"}}, {}, "{}");
}
// But heal content that isn't partial.
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
}
static void test_json_with_dumped_args() {
// Partial content.
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
test("{\"content\": ", true, {}, {{"content"}}, "{}");
// If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
for (const auto & src : barely_healable_jsons) {
test(src, true, {{}}, {}, src);
}
// Full JSON w/ args
for (auto parse_as_partial : {true, false}) {
test_with_args(
R"({"name": "python", "args": {"arg1": 1}})",
R"({"name":"python","args":"{\"arg1\":1}"})",
parse_as_partial,
/* is_partial= */ false
);
}
// Partial JSON w/ partial args
test_with_args(
R"({"foo": "bar", "args": {")",
R"({"foo":"bar","args":"{\""})"
);
// Partial args broken in object key
test_with_args(
R"({"foo": "bar", "args": {"ar)",
R"({"foo":"bar","args":"{\"ar"})"
);
// Partial args broken after object key
test_with_args(
R"({"foo": "bar", "args": {"arg1")",
R"({"foo":"bar","args":"{\"arg1\""})"
);
// Partial args broken before object value
test_with_args(
R"({"foo": "bar", "args": {"arg1":)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken before object value (space)
test_with_args(
R"({"foo": "bar", "args": {"arg1": )",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that may not be complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1 )",
R"({"foo":"bar","args":"{\"arg1\":1"})"
);
// Partial args broken in object value that is incomplete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": ")",
R"({"foo":"bar","args":"{\"arg1\":\""})"
);
// Partial args broken in object value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": "1")",
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
);
// Partial args broken on array opening
test_with_args(
R"({"foo": "bar", "args": [)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is incomplete (int)
test_with_args(
R"({"foo": "bar", "args": [1)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": [1 )",
R"({"foo":"bar","args":"[1"})"
);
// Partial args broken on array value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": ["1")",
R"({"foo":"bar","args":"[\"1\""})"
);
// Partial args broken after array value
test_with_args(
R"({"foo": "bar", "args": [1,)",
R"({"foo":"bar","args":"[1,"})"
);
// Partial args broken on nested array
test_with_args(
R"({"foo": "bar", "args": {"arg1": [)",
R"({"foo":"bar","args":"{\"arg1\":["})"
);
}
static void test_positions() {
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_to(100); });
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_back(1); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(8);
assert_equals<size_t>(8, builder.pos());
builder.move_back(1);
assert_equals<size_t>(7, builder.pos());
assert_equals("world!", builder.consume_rest());
builder.move_to(0);
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.finish(); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(builder.input().size());
builder.finish();
}
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
builder.move_to(builder.input().size());
assert_equals<size_t>(builder.input().size(), builder.pos());
builder.finish();
}
}
int main() {
test_positions();
test_json_with_dumped_args_no_args();
test_json_with_dumped_args();
test_reasoning();
test_regex();
std::cout << "All tests passed!\n";
return 0;
}

View File

@@ -7,7 +7,9 @@
#include "llama.h"
#include "common.h"
#include "chat-template.hpp"
#include "minja/chat-template.hpp"
#include "minja/minja.hpp"
#include "chat.h"
static std::string normalize_newlines(const std::string& s) {
#ifdef _WIN32
@@ -150,12 +152,12 @@ int main(void) {
// test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::vector<llama_chat_msg> chat2;
llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
std::vector<common_chat_msg> chat2;
common_chat_msg sys_msg{"system", "You are a helpful assistant"};
auto fmt_sys = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", "");
auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
return output;
};
@@ -170,11 +172,11 @@ int main(void) {
chat2.push_back({"system", "You are a helpful assistant"});
chat2.push_back({"user", "Hello"});
chat2.push_back({"assistant", "I am assistant"});
llama_chat_msg new_msg{"user", "How are you"};
common_chat_msg new_msg{"user", "How are you"};
auto fmt_single = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", "");
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, /* use_jinja= */ false);
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;

1707
tests/test-chat.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -149,7 +149,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
}
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
}
static void test_simple_grammar() {

File diff suppressed because it is too large Load Diff

237
tests/test-json-partial.cpp Normal file
View File

@@ -0,0 +1,237 @@
#include "common.h"
#include "json-partial.h"
#include <exception>
#include <iostream>
#include <stdexcept>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void test_json_healing() {
auto parse = [](const std::string & str) {
std::cerr << "# Parsing: " << str << '\n';
std::string::const_iterator it = str.begin();
const auto end = str.end();
common_json out;
std::string healing_marker = "$llama.cpp.json$";
if (common_json_parse(it, end, healing_marker, out)) {
auto dump = out.json.dump();
std::cerr << "Parsed: " << dump << '\n';
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
std::string result;
if (!out.healing_marker.json_dump_marker.empty()) {
auto i = dump.find(out.healing_marker.json_dump_marker);
if (i == std::string::npos) {
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
}
result = dump.substr(0, i);
} else {
result = dump;
}
std::cerr << "Result: " << result << '\n';
if (string_starts_with(str, result)) {
std::cerr << "Failure!\n";
}
// return dump;
} else {
throw std::runtime_error("Failed to parse: " + str);
}
};
auto parse_all = [&](const std::string & str) {
for (size_t i = 1; i < str.size(); i++) {
parse(str.substr(0, i));
}
};
parse_all("{\"a\": \"b\"}");
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
parse_all("[{\"a\": \"b\"}]");
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
assert_equals<std::string>(expected, out.json.dump());
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
// No healing needed:
test(
{
R"([{"a":"b"}, "y"])",
},
R"([{"a":"b"},"y"])",
""
);
// Partial literals can't be healed:
test(
{
R"([1)",
R"([tru)",
R"([n)",
R"([nul)",
R"([23.2)",
},
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"({"a": 1)",
R"({"a": tru)",
R"({"a": n)",
R"({"a": nul)",
R"({"a": 23.2)",
},
R"({"a":"$foo"})",
R"("$foo)"
);
test(
{
R"({)",
},
R"({"$foo":1})",
R"("$foo)"
);
test(
{
R"([)",
},
R"(["$foo"])",
R"("$foo)"
);
// Healing right after a full literal
test(
{
R"(1 )",
},
R"(1)",
""
);
test(
{
R"(true)",
R"(true )",
},
R"(true)",
""
);
test(
{
R"(null)",
R"(null )",
},
R"(null)",
""
);
test(
{
R"([1 )",
},
R"([1,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{})",
R"([{} )",
},
R"([{},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true)",
},
// TODO: detect the true/false/null literal was complete
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"([true )",
},
R"([true,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true,)",
},
R"([true,"$foo"])",
R"("$foo)"
);
// Test nesting
test(
{
R"([{"a": [{"b": [{)",
},
R"([{"a":[{"b":[{"$foo":1}]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": [{"b": [)",
},
R"([{"a":[{"b":["$foo"]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": "b"})",
R"([{"a": "b"} )",
},
R"([{"a":"b"},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{"a": "b"},)",
R"([{"a": "b"}, )",
},
R"([{"a":"b"},"$foo"])",
R"("$foo)"
);
test(
{
R"({ "code)",
},
R"({"code$foo":1})",
R"($foo)"
);
test(
{
R"({ "code\)",
},
R"({"code\\$foo":1})",
R"(\$foo)"
);
test(
{
R"({ "code")",
},
R"({"code":"$foo"})",
R"(:"$foo)"
);
test(
{
R"({ "key")",
},
R"({"key":"$foo"})",
R"(:"$foo)"
);
}
int main() {
test_json_healing();
std::cerr << "All tests passed.\n";
return 0;
}

View File

@@ -1231,7 +1231,7 @@ int main() {
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());

View File

@@ -0,0 +1,288 @@
// Tests common_regex (esp. its partial final matches support).
#include "common.h"
#include "regex-partial.h"
#include <sstream>
#include <iostream>
#include <optional>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << " Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
struct test_case {
std::string pattern;
struct input_output {
std::string input;
common_regex_match output;
};
std::vector<input_output> inputs_outputs;
};
static std::string common_regex_match_type_name(common_regex_match_type type) {
switch (type) {
case COMMON_REGEX_MATCH_TYPE_NONE:
return "COMMON_REGEX_MATCH_TYPE_NONE";
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
case COMMON_REGEX_MATCH_TYPE_FULL:
return "COMMON_REGEX_MATCH_TYPE_FULL";
}
return "?";
}
static void test_regex() {
printf("[%s]\n", __func__);
auto test = [](const test_case & test_case) {
common_regex cr(test_case.pattern);
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
for (const auto & input_output : test_case.inputs_outputs) {
std::cout << " Input: " << input_output.input << '\n';
auto m = cr.search(input_output.input, 0);
if (m != input_output.output) {
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
std::ostringstream ss;
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
ss << "<no match>";
} else {
GGML_ASSERT(!input_output.output.groups.empty());
std::vector<std::string> parts;
for (const auto & g : m->groups) {
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
}
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
}
return ss.str();
};
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
std::cout << " Got: " << match_to_str(m) << '\n';
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
throw std::runtime_error("Test failed");
}
}
};
test({
"a",
{
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
}
});
test({
"abcd",
{
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"d", {}},
{"bcd", {}},
{"cde", {}},
{"cd", {}},
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
{"abbie", {}},
{"", {}},
}
});
test({
".*?ab",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
}
});
test({
"a.*?b",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"d", {}},
{"b", {}},
}
});
test({
"ab(?:cd){2,4}ef",
{
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abcde", {}},
{"abcdef", {}},
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
{"abcdcdcdcdcdef", {}},
{"abcde", {}},
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
}
});
test({
"a(?:rte| pure )fact",
{
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"fact", {}},
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
{"" , {}},
{"pure", {}},
{"pure fact", {}},
}
});
test({
"abc",
{
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"b", {}},
{"c", {}},
{"", {}},
}
});
test({
"(?:abc)?\\s*def",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
}
});
test({
"a+b",
{
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
}
});
test({
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
{
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
}
});
}
static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);
assert_equals<std::string>(
"((?:(?:c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abc"));
assert_equals<std::string>(
"(a+)[\\s\\S]*",
regex_to_reversed_partial_regex("a+"));
assert_equals<std::string>(
"(a*)[\\s\\S]*",
regex_to_reversed_partial_regex("a*"));
assert_equals<std::string>(
"(a?)[\\s\\S]*",
regex_to_reversed_partial_regex("a?"));
assert_equals<std::string>(
"([a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]"));
assert_equals<std::string>(
"((?:\\w+)?[a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals<std::string>(
"((?:a|b))[\\s\\S]*",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>(
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>(
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>(
"((?:(?:b)?a)?.*)[\\s\\S]*",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>(
"((?:(?:b)?.*)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>(
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>(
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>(
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("ab{2,4}c"));
}
int main() {
test_regex_to_reversed_partial_regex();
test_regex();
std::cout << "All tests passed.\n";
}