ik_llama.cpp/common/chat.cpp
Anton Sokolchenko 9ee72225dc Function calling support for Kimi-K2 (#628)
* Implement function calling / tools for ik_llama.cpp for Kimi K2

* Implement basic tool choice

* Backport llama.cpp tool calls support

* Enhance function calls with improved chat parser and string utilities

- Add new chat.h/chat.cpp and chat-parser.h/chat-parser.cpp for better chat handling
- Improve function calls parsing with fallback to llama.cpp builder pattern
- Add string utility functions (starts_with, ends_with, find_partial_stop)
- Update README with function calls testing instructions
- Enhance Kimi K2 parser and function calls documentation
- Add comprehensive test suite for function calls
- Update CMakeLists.txt and Makefile for new components
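
The string helpers named above are not shown in this listing. Below is a minimal sketch of what such helpers usually do. It assumes that `find_partial_stop` returns the position where a trailing fragment of the text could be the start of a stop sequence. The `sketch_*` names are hypothetical, and this is not the ik_llama.cpp code:

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <string_view>

// Hypothetical helper names; sketches only, not the ik_llama.cpp implementations.
bool sketch_starts_with(std::string_view s, std::string_view prefix) {
    return s.size() >= prefix.size() && s.compare(0, prefix.size(), prefix) == 0;
}

bool sketch_ends_with(std::string_view s, std::string_view suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Returns the index where the longest suffix of `text` that is also a prefix of `stop`
// begins (a complete `stop` at the end counts too), or std::string::npos if none exists.
// A streaming server can hold back text from that index until it knows whether the
// stop sequence is really being emitted.
std::size_t sketch_find_partial_stop(std::string_view text, std::string_view stop) {
    for (std::size_t len = std::min(text.size(), stop.size()); len > 0; --len) {
        if (sketch_ends_with(text, stop.substr(0, len))) {
            return text.size() - len;
        }
    }
    return std::string::npos;
}
```

For example, with the stop string `<tool_call>` the text `Hello <tool` yields 6. A streaming server could therefore send `Hello ` and hold back `<tool` until more tokens arrive.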

* Enhance function calling with unified streaming and parser improvements

- Fix streaming content cleanup to prevent function syntax in output
- Unify content extraction patterns with llama.cpp approach
- Improve Kimi K2 parser robustness and partial content handling
- Add comprehensive test coverage for function call scenarios
- Optimize chat message parsing and diff computation
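
The diff computation used for streaming reduces to a prefix check. Each update carries the full text generated so far, and only the part appended since the last update is sent. The following stand-alone sketch of that rule is modelled on the `string_diff` helper in this file. The name `sketch_stream_delta` is hypothetical:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-alone version of the prefix diff used for streaming:
// returns only the text appended since the previous update.
std::string sketch_stream_delta(const std::string & last, const std::string & current) {
    if (current.rfind(last, 0) == 0) {   // current starts with last (also true if last is empty)
        return current.substr(last.size());
    }
    if (last.rfind(current, 0) == 0) {   // text shrank: a trailing partial stop word was erased
        return "";
    }
    throw std::runtime_error("stream diverged: previous text is not a prefix of the new text");
}
```

Concatenating every returned delta reproduces the final text. The empty-string case mirrors the comment in `string_diff`: the previous text ended on a partial stop word that the current text no longer contains.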

* Replace hardcoded values in kimi_k2_parser.hpp with named constants

- Add compile-time constants for all token format markers
- Add compile-time constants for XML format markers
- Add compile-time constants for simple format patterns
- Replace all hardcoded string literals with named constants
- Use compile-time length calculation to avoid manual counting
- Improve maintainability and reduce magic numbers throughout parser
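
Below is a minimal sketch of the named-constant pattern, assuming C++17 `std::string_view`. The marker names are hypothetical. They use the XML-style `<tool_call>` tag, not the actual constants in `kimi_k2_parser.hpp`:

```cpp
#include <cstddef>
#include <string_view>

// Hypothetical constant names and values; the real markers live in kimi_k2_parser.hpp.
namespace sketch_markers {
inline constexpr std::string_view TOOL_CALL_OPEN  = "<tool_call>";
inline constexpr std::string_view TOOL_CALL_CLOSE = "</tool_call>";
// Lengths are computed by the compiler instead of being counted by hand.
inline constexpr std::size_t TOOL_CALL_OPEN_LEN  = TOOL_CALL_OPEN.size();
inline constexpr std::size_t TOOL_CALL_CLOSE_LEN = TOOL_CALL_CLOSE.size();
}  // namespace sketch_markers

// The derived lengths are constant expressions, usable in static_assert.
static_assert(sketch_markers::TOOL_CALL_OPEN_LEN == 11);
static_assert(sketch_markers::TOOL_CALL_CLOSE_LEN == 12);

// Example use: does an opening tool-call marker start at `pos`?
constexpr bool sketch_opens_tool_call_at(std::string_view s, std::size_t pos) {
    return pos <= s.size() &&
           s.substr(pos, sketch_markers::TOOL_CALL_OPEN_LEN) == sketch_markers::TOOL_CALL_OPEN;
}
```

`size()` on a `constexpr std::string_view` is itself a constant expression. Changing a marker string therefore updates its length automatically, with no separate number to keep in sync.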

* Fix duplicate common_chat_parse definition

- Remove duplicate implementation from chat-parser.cpp
- Keep single implementation in chat.cpp following llama.cpp patterns
- Resolves linker error: multiple definition of common_chat_parse

* Fix JSON assertion failure in function call parsing

- Add proper validation that 'function' field is an object before accessing nested keys
- Handle missing 'arguments' field gracefully with default "{}"
- Prevents crash when parsing malformed tool call JSON structures

* Add comprehensive Qwen3 XML tool calling support with unit tests

- Implement Qwen3 XML parser with <tool_call>{"name": "func", "arguments": {...}}</tool_call> format
- Add model detection and routing for Qwen3 vs Kimi-K2 formats
- Create 8 comprehensive unit tests covering parsing, streaming, error handling
- Fix token format cleaning bug in kimi_k2_parser.hpp processing order
- Remove progressive parsing code and related utilities
- Add tool injection support for Qwen3 format in server utils
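
Below is a rough sketch of splitting Qwen3-style output into plain content and raw `<tool_call>` payloads, using only the standard library. It assumes the payload between the tags is the JSON object shown above and leaves JSON parsing to a real JSON library. `SketchQwenSplit` and `sketch_split_qwen3` are hypothetical names. Keeping an unterminated block as content is one possible policy, not necessarily the one the parser uses:

```cpp
#include <cstddef>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical result type: text outside the tags plus the raw payload of each call.
struct SketchQwenSplit {
    std::string content;
    std::vector<std::string> payloads;  // raw JSON text, not parsed here
};

// Hypothetical helper: collects every complete <tool_call>...</tool_call> block.
// An unterminated block (e.g. truncated output) is left in `content` unchanged.
SketchQwenSplit sketch_split_qwen3(std::string_view text) {
    constexpr std::string_view open  = "<tool_call>";
    constexpr std::string_view close = "</tool_call>";
    constexpr std::string_view ws    = " \t\r\n";
    SketchQwenSplit out;
    std::size_t pos = 0;
    while (true) {
        const std::size_t start = text.find(open, pos);
        if (start == std::string_view::npos) {
            break;
        }
        const std::size_t body = start + open.size();
        const std::size_t end  = text.find(close, body);
        if (end == std::string_view::npos) {
            break;  // no closing tag: keep the rest as content
        }
        out.content += text.substr(pos, start - pos);
        // Trim surrounding whitespace, such as newlines around the JSON object.
        std::string_view payload = text.substr(body, end - body);
        const std::size_t first = payload.find_first_not_of(ws);
        const std::size_t last  = payload.find_last_not_of(ws);
        payload = (first == std::string_view::npos) ? std::string_view{}
                                                    : payload.substr(first, last - first + 1);
        out.payloads.emplace_back(payload);
        pos = end + close.size();
    }
    out.content += text.substr(pos);
    return out;
}
```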

* Add DeepSeek R1 function calling support with comprehensive unit tests

- Implement complete DeepSeek R1 tool call parsing in common_chat_parser.cpp
- Add DeepSeek R1 model detection and tool injection in deepseek_r1_tools.hpp
- Update function_calls.hpp with DeepSeek R1 integration and content extraction
- Update documentation to reflect support for Kimi-K2, Qwen3, and DeepSeek R1 models
- Add comprehensive unit tests for DeepSeek R1 reasoning, tool calls, and integration
- Port exact implementation patterns from original llama.cpp for compatibility

Key features:
- Native DeepSeek R1 format: <|tool▁calls▁begin|>function<|tool▁sep|>name```json{}```<|tool▁call▁end|><|tool▁calls▁end|>
- Reasoning content extraction from <think>...</think> tags
- Multiple tool calls support with separate call blocks
- Model detection for deepseek-r1, deepseek_r1 naming patterns
- Integration with incremental parsing and streaming support
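
The reasoning extraction listed above can be sketched as follows. The sketch assumes reasoning appears only as a leading `<think>...</think>` block. An unclosed block (mid-stream) counts as reasoning so far. The names are hypothetical. The parser in this file calls `builder.try_parse_reasoning("<think>", "</think>")` instead:

```cpp
#include <cstddef>
#include <string>
#include <string_view>

// Hypothetical result type for the split.
struct SketchReasoningSplit {
    std::string reasoning;  // text inside <think>...</think>
    std::string content;    // text after </think>
};

// Hypothetical helper: only a leading <think> block (after optional whitespace) is reasoning.
// If </think> has not arrived yet, everything after <think> is treated as reasoning so far.
SketchReasoningSplit sketch_split_reasoning(std::string_view text) {
    constexpr std::string_view open  = "<think>";
    constexpr std::string_view close = "</think>";
    SketchReasoningSplit out;
    const std::size_t start = text.find_first_not_of(" \t\r\n");
    if (start == std::string_view::npos || text.substr(start, open.size()) != open) {
        out.content = std::string(text);  // no reasoning block at all
        return out;
    }
    const std::size_t body = start + open.size();
    const std::size_t end  = text.find(close, body);
    if (end == std::string_view::npos) {
        out.reasoning = std::string(text.substr(body));
        return out;
    }
    out.reasoning = std::string(text.substr(body, end - body));
    out.content   = std::string(text.substr(end + close.size()));
    return out;
}
```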

* Add partial parsing support for JSON and regex

- json-partial.h/cpp: JSON partial parsing functionality
- regex-partial.h/cpp: Regex partial parsing functionality
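
Partial JSON parsing lets a truncated prefix be parsed early by deciding how it could be closed. Below is a deliberately simplified sketch of that idea, not the `json-partial.h` API. It tracks open strings, arrays and objects and returns the characters that would close them. `sketch_json_closing_suffix` is a hypothetical name. The helper ignores cases that a real implementation must handle, such as a prefix that ends after a key, colon or comma:

```cpp
#include <string>
#include <string_view>

// Hypothetical helper, not the json-partial.h API: returns the characters that would
// close every string, array and object left open in a truncated JSON prefix.
// It does not validate the JSON and does not repair truncated \uXXXX escapes.
std::string sketch_json_closing_suffix(std::string_view prefix) {
    std::string closers;  // pending '}' / ']' in opening order; innermost is last
    bool in_string = false;
    bool escaped   = false;
    for (char c : prefix) {
        if (in_string) {
            if (escaped) {
                escaped = false;  // this character completes an escape such as \" or \n
            } else if (c == '\\') {
                escaped = true;
            } else if (c == '"') {
                in_string = false;
            }
            continue;
        }
        switch (c) {
            case '"': in_string = true; break;
            case '{': closers.push_back('}'); break;
            case '[': closers.push_back(']'); break;
            case '}':
            case ']':
                if (!closers.empty()) {
                    closers.pop_back();
                }
                break;
            default:
                break;
        }
    }
    std::string suffix;
    if (in_string) {
        if (escaped) {
            suffix += '\\';  // turn a dangling backslash into an escaped backslash
        }
        suffix += '"';
    }
    suffix.append(closers.rbegin(), closers.rend());  // close innermost containers first
    return suffix;
}
```

Appending the suffix to `{"name": "get_weather", "arguments": {"city": "Par` produces parseable JSON. A streaming parser can therefore expose the arguments received so far.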

* Add format_chat integration tests for Qwen3 tool injection

- Add test_qwen3_format_chat_integration() to validate tool injection pipeline
- Test tool injection conditions and system message enhancement
- Verify JSON formatting and anti-preamble instructions
- Add comprehensive test documentation

Tests confirm tool injection works correctly - conversational preamble
issue is not in ik_llama.cpp but likely in UI configuration.

* Fix Qwen3 tool call parsing - pass model name to parser

Server was not passing model name to parse_chat_message_incremental(),
causing Qwen3 to fall back to Kimi-K2 parser and return tool calls
as content instead of proper tool_calls array.

* Fix non-streaming path to use model-specific parsing

Non-streaming responses were hardcoded to use Kimi-K2 format,
causing Qwen3 XML tool calls to be returned as content instead
of proper tool_calls array. Now uses same model detection as
streaming path for consistency.
2025-07-23 18:11:42 +02:00
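
The routing fix amounts to calling one model-name detection function from both the streaming and non-streaming paths. Below is a minimal sketch that assumes case-insensitive substring matching on the model name and the Kimi-K2 fallback described above. The enum and function names are hypothetical, and the real detection code may use different patterns:

```cpp
#include <algorithm>
#include <cctype>
#include <string>

// Hypothetical format tags; the real code uses its own enums and detection helpers.
enum class SketchToolFormat { KimiK2, Qwen3, DeepSeekR1 };

// Case-insensitive substring routing on the model name, falling back to Kimi-K2.
SketchToolFormat sketch_detect_format(std::string model) {
    std::transform(model.begin(), model.end(), model.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    if (model.find("qwen3") != std::string::npos) {
        return SketchToolFormat::Qwen3;
    }
    if (model.find("deepseek-r1") != std::string::npos ||
        model.find("deepseek_r1") != std::string::npos) {
        return SketchToolFormat::DeepSeekR1;
    }
    return SketchToolFormat::KimiK2;
}
```

If both response paths call the same function, a model cannot be parsed as Qwen3 in one path and as Kimi-K2 in the other.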

204 lines
8.3 KiB
C++


#include "chat.h"
#include "chat-parser.h"
#include "common.h"
#include "../examples/server/parsers/kimi_k2_parser.hpp"
#include <stdexcept>
#include <string>
#include <vector>
#include "json.hpp"
using json = nlohmann::ordered_json;

static std::string string_diff(const std::string & last, const std::string & current) {
    if (last.empty()) {
        return current;
    }
    if (!string_starts_with(current, last)) {
        if (string_starts_with(last, current)) {
            // This happens if the last generation ended on a partial stop word (not erased),
            // and the current ended on a stop word (erased).
            return "";
        }
        throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
    }
    return current.substr(last.size());
}

std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
    std::vector<common_chat_msg_diff> diffs;
    if (previous_msg.reasoning_content != new_msg.reasoning_content) {
        auto & diff = diffs.emplace_back();
        diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
    }
    if (previous_msg.content != new_msg.content) {
        auto & diff = diffs.emplace_back();
        diff.content_delta = string_diff(previous_msg.content, new_msg.content);
    }

    // While streaming, tool calls are only ever appended, never removed.
    if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
        throw std::runtime_error("Invalid diff: now finding less tool calls!");
    }

    // Only the last previously seen tool call can still be growing: emit its argument delta,
    // plus its id and name if the id changed.
    if (!previous_msg.tool_calls.empty()) {
        auto idx = previous_msg.tool_calls.size() - 1;
        const auto & pref = previous_msg.tool_calls[idx];
        const auto & newf = new_msg.tool_calls[idx];
        if (pref.name != newf.name) {
            throw std::runtime_error("Invalid diff: tool call mismatch!");
        }
        auto args_diff = string_diff(pref.arguments, newf.arguments);
        if (!args_diff.empty() || pref.id != newf.id) {
            auto & diff = diffs.emplace_back();
            diff.tool_call_index = idx;
            if (pref.id != newf.id) {
                diff.tool_call_delta.id = newf.id;
                diff.tool_call_delta.name = newf.name;
            }
            diff.tool_call_delta.arguments = args_diff;
        }
    }

    // Tool calls that appear for the first time are emitted in full.
    for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
        auto & diff = diffs.emplace_back();
        diff.tool_call_index = idx;
        diff.tool_call_delta = new_msg.tool_calls[idx];
    }
    return diffs;
}

// Format parsing functions (ported from upstream llama.cpp)

// Content-only parsing (internal implementation, matches llama.cpp exactly)
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
    builder.add_content(builder.consume_rest());
}

static void common_chat_parse_generic(common_chat_msg_parser & builder) {
    if (!builder.syntax().enable_tool_calls) {
        builder.add_content(builder.consume_rest());
        return;
    }
    static const std::vector<std::vector<std::string>> content_paths = {
        {"response"},
    };
    static const std::vector<std::vector<std::string>> args_paths = {
        {"tool_call", "arguments"},
        {"tool_calls", "arguments"},
    };
    auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
    if (data.value.contains("tool_calls")) {
        if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
            throw common_chat_msg_partial_exception("incomplete tool calls");
        }
    } else if (data.value.contains("tool_call")) {
        if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
            throw common_chat_msg_partial_exception("incomplete tool call");
        }
    } else if (data.value.contains("response")) {
        const auto & response = data.value.at("response");
        builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
        if (data.is_partial) {
            throw common_chat_msg_partial_exception("incomplete response");
        }
    } else {
        throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
    }
}

static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
    builder.try_parse_reasoning("<think>", "</think>");
    if (!builder.syntax().enable_tool_calls) {
        builder.add_content(builder.consume_rest());
        return;
    }
    // DeepSeek R1 special tokens are delimited by fullwidth vertical bars (U+FF5C '｜'),
    // not ASCII '|', and use U+2581 '▁' as the word separator.
    static const common_regex tool_calls_begin("(?:<｜tool▁calls▁begin｜>|<｜tool_calls_begin｜>|<｜tool calls begin｜>|<｜tool\\\\_calls\\\\_begin｜>|<｜tool▁calls｜>)");
    static const common_regex tool_calls_end("<｜tool▁calls▁end｜>");
    static const common_regex function_regex("(?:<｜tool▁call▁begin｜>)?function<｜tool▁sep｜>([^\n]+)\n```json\n");
    static const common_regex close_regex("```[\\s\\r\\n]*<｜tool▁call▁end｜>");
    // Simplified tool calls parsing for DEEPSEEK_R1
    if (auto res = builder.try_find_regex(tool_calls_begin)) {
        while (auto func_res = builder.try_find_regex(function_regex)) {
            auto function_name = builder.str(func_res->groups[1]);
            auto args_json = builder.try_consume_json();
            if (args_json) {
                builder.add_tool_call(function_name, "", args_json->json.dump());
                builder.try_consume_regex(close_regex);
            } else {
                throw common_chat_msg_partial_exception("incomplete tool call JSON");
            }
        }
        builder.try_consume_regex(tool_calls_end);
        builder.add_content(builder.consume_rest());
    } else {
        builder.add_content(builder.consume_rest());
    }
}

static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
    // Delegate to existing Kimi-K2 implementation for backward compatibility
    auto result = kimi_k2::parse_tool_calls(builder.input());
    for (const auto & tc_json : result) {
        common_chat_tool_call tc;
        tc.id = tc_json.value("id", "");
        // Only accept entries whose "function" object carries a "name"; default arguments to "{}".
        if (tc_json.contains("function") && tc_json["function"].contains("name")) {
            tc.name      = tc_json["function"]["name"].get<std::string>();
            tc.arguments = tc_json["function"].value("arguments", "{}");
            builder.add_tool_call(tc);
        }
    }
    // Add cleaned content (removes tool call syntax)
    builder.add_content(kimi_k2::clean_content(builder.input()));
    // The whole input was handled above; advance the cursor so finish() sees no unconsumed text.
    builder.move_to(builder.input().size());
}

// Main parsing dispatch function
static void common_chat_parse(common_chat_msg_parser & builder) {
    switch (builder.syntax().format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY:
            common_chat_parse_content_only(builder);
            break;
        case COMMON_CHAT_FORMAT_GENERIC:
            common_chat_parse_generic(builder);
            break;
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
            common_chat_parse_deepseek_r1(builder);
            break;
        case COMMON_CHAT_FORMAT_KIMI_K2:
            common_chat_parse_kimi_k2(builder);
            break;
        default:
            throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
    }
    builder.finish();
}

// Main public parsing function
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
    common_chat_msg_parser builder(input, is_partial, syntax);
    try {
        common_chat_parse(builder);
    } catch (const common_chat_msg_partial_exception &) {
        if (is_partial) {
            // Re-throw for partial cases to signal incomplete parsing
            throw;
        }
        // Fallback to content-only on parsing errors
        builder.clear_tools();
        builder.move_to(0);
        common_chat_parse_content_only(builder);
    }
    return builder.result();
}

// Get format name for debugging/logging
const char * common_chat_format_name(common_chat_format format) {
    switch (format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "content_only";
        case COMMON_CHAT_FORMAT_GENERIC:      return "generic";
        case COMMON_CHAT_FORMAT_DEEPSEEK_R1:  return "deepseek_r1";
        case COMMON_CHAT_FORMAT_KIMI_K2:      return "kimi_k2";
        default:                              return "unknown";
    }
}
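
The public `common_chat_parse` in this file falls back to plain content when a final message cannot be parsed. While streaming, it re-throws `common_chat_msg_partial_exception` instead. The self-contained sketch below follows the same control flow. `SketchMsg`, `SketchPartialError` and the toy `CALL:<name>;` syntax are hypothetical stand-ins for the real types and formats:

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for common_chat_msg and common_chat_msg_partial_exception.
struct SketchMsg {
    std::string content;
    std::vector<std::string> tool_calls;
};

struct SketchPartialError : std::runtime_error {
    using std::runtime_error::runtime_error;
};

// Toy structured parser: "CALL:<name>;rest" is one tool call followed by content.
// A missing ';' means the tool call is incomplete.
SketchMsg sketch_structured_parse(const std::string & input) {
    const std::string tag = "CALL:";
    if (input.rfind(tag, 0) != 0) {
        return SketchMsg{input, {}};
    }
    const auto end = input.find(';');
    if (end == std::string::npos) {
        throw SketchPartialError("incomplete tool call");
    }
    return SketchMsg{input.substr(end + 1), {input.substr(tag.size(), end - tag.size())}};
}

// Final (non-partial) input that fails to parse is returned verbatim as content;
// partial (streaming) input re-throws so the caller can treat the chunk as incomplete.
SketchMsg sketch_parse_with_fallback(const std::string & input, bool is_partial) {
    try {
        return sketch_structured_parse(input);
    } catch (const SketchPartialError &) {
        if (is_partial) {
            throw;
        }
        return SketchMsg{input, {}};
    }
}
```

A final message never loses text this way. In partial mode, the exception tells the caller to wait for more tokens before emitting half-written tool-call syntax.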