Function calling support for Kimi-K2 (#628)

* Implement function calling / tools for ik_llama.cpp for Kimi K2

* Implement basic tool choice

* Backport llama.cpp tool calls support

* Enhance function calls with improved chat parser and string utilities

- Add new chat.h/chat.cpp and chat-parser.h/chat-parser.cpp for better chat handling
- Improve function calls parsing with fallback to llama.cpp builder pattern
- Add string utility functions (starts_with, ends_with, find_partial_stop)
- Update README with function calls testing instructions
- Enhance Kimi K2 parser and function calls documentation
- Add comprehensive test suite for function calls
- Update CMakeLists.txt and Makefile for new components

* Enhance function calling with unified streaming and parser improvements

- Fix streaming content cleanup to prevent function syntax in output
- Unify content extraction patterns with llama.cpp approach
- Improve Kimi K2 parser robustness and partial content handling
- Add comprehensive test coverage for function call scenarios
- Optimize chat message parsing and diff computation

* Replace hardcoded values in kimi_k2_parser.hpp with named constants

- Add compile-time constants for all token format markers
- Add compile-time constants for XML format markers
- Add compile-time constants for simple format patterns
- Replace all hardcoded string literals with named constants
- Use compile-time length calculation to avoid manual counting
- Improve maintainability and reduce magic numbers throughout parser

* Fix duplicate common_chat_parse definition

- Remove duplicate implementation from chat-parser.cpp
- Keep single implementation in chat.cpp following llama.cpp patterns
- Resolves linker error: multiple definition of common_chat_parse

* Fix JSON assertion failure in function call parsing

- Add proper validation that 'function' field is an object before accessing nested keys
- Handle missing 'arguments' field gracefully with default "{}"
- Prevents crash when parsing malformed tool call JSON structures

* Add comprehensive Qwen3 XML tool calling support with unit tests

- Implement Qwen3 XML parser with <tool_call>{"name": "func", "arguments": {...}}</tool_call> format
- Add model detection and routing for Qwen3 vs Kimi-K2 formats
- Create 8 comprehensive unit tests covering parsing, streaming, error handling
- Fix token format cleaning bug in kimi_k2_parser.hpp processing order
- Remove progressive parsing code and related utilities
- Add tool injection support for Qwen3 format in server utils

* Add DeepSeek R1 function calling support with comprehensive unit tests

- Implement complete DeepSeek R1 tool call parsing in common_chat_parser.cpp
- Add DeepSeek R1 model detection and tool injection in deepseek_r1_tools.hpp
- Update function_calls.hpp with DeepSeek R1 integration and content extraction
- Update documentation to reflect support for Kimi-K2, Qwen3, and DeepSeek R1 models
- Add comprehensive unit tests for DeepSeek R1 reasoning, tool calls, and integration
- Port exact implementation patterns from original llama.cpp for compatibility

Key features:
- Native DeepSeek R1 format: <|tool▁calls▁begin|>function<|tool▁sep|>name```json{}```<|tool▁call▁end|><|tool▁calls▁end|>
- Reasoning content extraction from <think>...</think> tags
- Multiple tool calls support with separate call blocks
- Model detection for deepseek-r1, deepseek_r1 naming patterns
- Integration with incremental parsing and streaming support

* Add partial parsing support for JSON and regex

- json-partial.h/cpp: JSON partial parsing functionality
- regex-partial.h/cpp: Regex partial parsing functionality

* Add format_chat integration tests for Qwen3 tool injection

- Add test_qwen3_format_chat_integration() to validate tool injection pipeline
- Test tool injection conditions and system message enhancement
- Verify JSON formatting and anti-preamble instructions
- Add comprehensive test documentation

Tests confirm tool injection works correctly - conversational preamble
issue is not in ik_llama.cpp but likely in UI configuration.

* Fix Qwen3 tool call parsing - pass model name to parser

Server was not passing model name to parse_chat_message_incremental(),
causing Qwen3 to fall back to Kimi-K2 parser and return tool calls
as content instead of proper tool_calls array.

* Fix non-streaming path to use model-specific parsing

Non-streaming responses were hardcoded to use Kimi-K2 format,
causing Qwen3 XML tool calls to be returned as content instead
of proper tool_calls array. Now uses same model detection as
streaming path for consistency.
This commit is contained in:
Anton Sokolchenko
2025-07-23 18:11:42 +02:00
committed by GitHub
parent 0451f10a42
commit 3701fb1686
26 changed files with 6978 additions and 9 deletions

View File

@@ -20,6 +20,9 @@
#include "json.hpp"
#include "index.html.gz.hpp"
#include "loading.html.hpp"
#include "function_calls.hpp"
#include "streaming_chat.hpp"
#include "../../common/chat-parser.h"
#include <atomic>
#include <chrono>
@@ -30,6 +33,8 @@
#include <thread>
#include <signal.h>
#include <memory>
#include <random>
#include <algorithm>
#include <src/llama-impl.h>
using json = nlohmann::ordered_json;
@@ -38,6 +43,7 @@ bool server_verbose = false;
bool server_log_json = true;
enum stop_type {
STOP_TYPE_FULL,
STOP_TYPE_PARTIAL,
@@ -135,6 +141,74 @@ struct server_task_result {
std::unordered_map<int, server_task_result > server_task_result_dict = {};
// Helper functions for content cleaning
static std::string remove_simple_function_calls(const std::string& content) {
std::string cleaned = content;
const std::string func_pattern = "functions.";
size_t pos = 0;
while ((pos = cleaned.find(func_pattern, pos)) != std::string::npos) {
size_t func_start = pos;
// Find the opening brace for arguments
size_t brace_pos = cleaned.find('{', pos);
if (brace_pos == std::string::npos) {
pos += func_pattern.length();
continue;
}
// Find the matching closing brace
int brace_count = 1;
size_t end_pos = brace_pos + 1;
while (end_pos < cleaned.length() && brace_count > 0) {
if (cleaned[end_pos] == '{') brace_count++;
else if (cleaned[end_pos] == '}') brace_count--;
end_pos++;
}
if (brace_count == 0) {
// Remove the entire function call
cleaned.erase(func_start, end_pos - func_start);
pos = func_start;
} else {
pos += func_pattern.length();
}
}
return cleaned;
}
static std::string remove_xml_function_calls(const std::string& content) {
std::string cleaned = content;
size_t pos = 0;
while ((pos = cleaned.find("<tool_call>", pos)) != std::string::npos) {
size_t tool_call_start = pos;
size_t tool_call_end = cleaned.find("</tool_call>", tool_call_start);
if (tool_call_end == std::string::npos) {
pos = tool_call_start + 11;
continue;
}
// Remove the entire XML tool call block
cleaned.erase(tool_call_start, tool_call_end - tool_call_start + 12);
pos = tool_call_start;
}
return cleaned;
}
static std::string clean_all_function_call_formats(const std::string& content) {
std::string cleaned = content;
// Remove XML format first
cleaned = remove_xml_function_calls(cleaned);
// Then remove simple format
cleaned = remove_simple_function_calls(cleaned);
// Trim whitespace from cleaned content
cleaned.erase(0, cleaned.find_first_not_of(" \t\n\r"));
cleaned.erase(cleaned.find_last_not_of(" \t\n\r") + 1);
return cleaned;
}
struct server_task_multi {
int id = -1;
@@ -191,6 +265,11 @@ struct server_slot {
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
// Streaming tool call state
ik_chat_msg previous_msg;
ik_chat_msg current_msg;
std::vector<std::string> tool_call_ids;
bool infill = false;
bool embedding = false;
bool has_next_token = true;
@@ -242,6 +321,37 @@ struct server_slot {
n_past_se = 0;
generated_token_probs.clear();
// Reset streaming tool call state
previous_msg = ik_chat_msg();
current_msg = ik_chat_msg();
tool_call_ids.clear();
}
// Update chat message and compute diffs for streaming tool calls
// Based on original llama.cpp update_chat_msg pattern
const ik_chat_msg & update_chat_msg(std::vector<ik_chat_msg_diff> & diffs) {
ik_chat_msg previous = current_msg;
try {
// Parse generated text incrementally (is_partial = true during generation)
bool is_partial = !stopped_eos && !stopped_word && !stopped_limit;
ik_chat_msg new_msg = parse_chat_message_incremental(generated_text, is_partial, oaicompat_model);
if (!new_msg.empty()) {
// Ensure tool call IDs are set consistently across streaming chunks
new_msg.ensure_tool_call_ids_set(tool_call_ids, generate_tool_call_id);
current_msg = new_msg;
// Compute diffs for streaming
diffs = ik_chat_msg_diff::compute_diffs(previous, current_msg);
}
} catch (const std::exception& e) {
// If parsing fails, don't update current_msg and return empty diffs
diffs.clear();
}
return current_msg;
}
bool has_budget(gpt_params &global_params) {
@@ -1499,13 +1609,43 @@ struct server_context {
res.id_multi = slot.id_multi;
res.error = false;
res.stop = false;
// Update chat message and compute diffs for streaming tool calls
// Following original llama.cpp pattern (server.cpp:2503)
std::vector<ik_chat_msg_diff> oaicompat_msg_diffs;
slot.update_chat_msg(oaicompat_msg_diffs);
// Following original llama.cpp pattern: send empty content in streaming mode
// Clean content comes through oaicompat_msg_diffs instead of raw tokens
res.data = json {
{"content", tkn.text_to_send},
{"content", ""}, // Empty - clean content provided via diffs
{"stop", false},
{"id_slot", slot.id},
{"multimodal", false}
};
// Store diffs for format_partial_response_oaicompat to use
// Convert ik_chat_msg_diff to JSON format for storage
json diffs_json = json::array();
for (const auto & diff : oaicompat_msg_diffs) {
json diff_obj;
if (!diff.content_delta.empty()) {
diff_obj["content_delta"] = diff.content_delta;
}
if (diff.tool_call_index != std::string::npos) {
diff_obj["tool_call_index"] = diff.tool_call_index;
diff_obj["tool_call_delta"] = {
{"id", diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name},
{"arguments", diff.tool_call_delta.arguments}
};
}
if (!diff_obj.empty()) {
diffs_json.push_back(diff_obj);
}
}
res.data["oaicompat_msg_diffs"] = diffs_json;
if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -2587,19 +2727,57 @@ static json format_final_response_oaicompat(const json& request, json result, co
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
// Parse tool calls using model-specific format detection
std::string model_name = json_value(request, "model", std::string(""));
// Use the same parsing logic as streaming path for consistency
ik_chat_msg parsed_msg = parse_chat_message_incremental(content, false, model_name);
// Convert to JSON format for compatibility
json tool_calls = json::array();
for (const auto & tc : parsed_msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments}
}},
{"id", tc.id}
});
}
bool has_tool_calls = !tool_calls.empty();
// Use cleaned content from parser (following original llama.cpp pattern)
if (has_tool_calls) {
content = parsed_msg.content; // Parser already cleaned the content
}
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
if (has_tool_calls) {
finish_reason = "tool_calls";
} else if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
json message = json{{"role", "assistant"}};
// Follow EXACT original llama.cpp pattern: content is null only when content is empty AND tool calls exist
if (content.empty() && has_tool_calls) {
message["content"] = nullptr; // Original: json() when content empty AND tool calls exist
} else {
message["content"] = content.empty() ? nullptr : content; // Original: use actual content otherwise
}
if (has_tool_calls) {
message["tool_calls"] = tool_calls;
}
json choices =
streaming ? json::array({ json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}} })
: json::array({ json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}} });
{"message", message}} });
std::time_t t = std::time(0);
@@ -2653,6 +2831,59 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
std::time_t t = std::time(0);
// Follow original llama.cpp pattern: Always process diffs and add final chunk
std::vector<json> streaming_chunks;
// Extract diffs from task result (populated by send_partial_response)
// Following original llama.cpp pattern where diffs are stored in task result
std::vector<ik_chat_msg_diff> diffs;
if (result.contains("oaicompat_msg_diffs") && result["oaicompat_msg_diffs"].is_array()) {
for (const auto & diff_json : result["oaicompat_msg_diffs"]) {
ik_chat_msg_diff diff;
// Extract content delta
diff.content_delta = diff_json.value("content_delta", "");
// Extract tool call data
if (diff_json.contains("tool_call_index")) {
diff.tool_call_index = diff_json["tool_call_index"];
if (diff_json.contains("tool_call_delta")) {
const auto & tc_delta = diff_json["tool_call_delta"];
diff.tool_call_delta.id = tc_delta.value("id", "");
diff.tool_call_delta.name = tc_delta.value("name", "");
diff.tool_call_delta.arguments = tc_delta.value("arguments", "");
}
} else {
diff.tool_call_index = std::string::npos;
}
diffs.push_back(diff);
}
}
streaming_chunks = generate_streaming_chunks(diffs, completion_id, modelname);
// Always add final chunk (like original llama.cpp)
if (!finish_reason.empty()) {
json finish_chunk = {
{"choices", json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
streaming_chunks.push_back(finish_chunk);
}
// Return streaming chunks (could be just final chunk if no diffs)
if (!streaming_chunks.empty()) {
return streaming_chunks;
}
// Fallback to original streaming logic for non-tool calls
json choices;
if (!finish_reason.empty()) {
@@ -2812,6 +3043,7 @@ int main(int argc, char ** argv) {
// TODO: not great to use extern vars
server_log_json = params.log_json;
server_verbose = params.verbosity > 0;
// struct that contains llama context and inference
server_context ctx_server;