mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 00:10:13 +00:00
Function calling support for Kimi-K2 (#628)
* Implement function calling / tools for ik_llama.cpp for Kimi K2
* Implement basic tool choice
* Backport llama.cpp tool calls support
* Enhance function calls with improved chat parser and string utilities
- Add new chat.h/chat.cpp and chat-parser.h/chat-parser.cpp for better chat handling
- Improve function calls parsing with fallback to llama.cpp builder pattern
- Add string utility functions (starts_with, ends_with, find_partial_stop)
- Update README with function calls testing instructions
- Enhance Kimi K2 parser and function calls documentation
- Add comprehensive test suite for function calls
- Update CMakeLists.txt and Makefile for new components
* Enhance function calling with unified streaming and parser improvements
- Fix streaming content cleanup to prevent function syntax in output
- Unify content extraction patterns with llama.cpp approach
- Improve Kimi K2 parser robustness and partial content handling
- Add comprehensive test coverage for function call scenarios
- Optimize chat message parsing and diff computation
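The diff computation mentioned above compares the message parsed from all text generated so far against the previously parsed one and emits only what changed. A minimal sketch of that idea, with a hypothetical struct and helper (not the actual ik_chat_msg_diff API):

#include <string>

// Hypothetical, simplified stand-in for the per-slot parsed message state.
struct simple_msg {
    std::string content;
};

// Return only the content appended since the previous parse, so each
// streaming chunk carries a delta rather than the full message so far.
static std::string compute_content_delta(const simple_msg & prev, const simple_msg & cur) {
    // Content normally only grows between chunks; if the prefix changed
    // (e.g. text was reclassified as part of a tool call), resend everything.
    if (cur.content.size() >= prev.content.size() &&
        cur.content.compare(0, prev.content.size(), prev.content) == 0) {
        return cur.content.substr(prev.content.size());
    }
    return cur.content;
}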
* Replace hardcoded values in kimi_k2_parser.hpp with named constants
- Add compile-time constants for all token format markers
- Add compile-time constants for XML format markers
- Add compile-time constants for simple format patterns
- Replace all hardcoded string literals with named constants
- Use compile-time length calculation to avoid manual counting
- Improve maintainability and reduce magic numbers throughout parser
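A minimal sketch of what the named constants replace, with illustrative marker names (the real constants live in kimi_k2_parser.hpp and cover the token, XML, and simple formats):

#include <cstddef>
#include <string_view>

// Illustrative only: each marker gets one named constant, and its length is
// computed at compile time instead of being hand-counted at every call site.
namespace tool_call_markers {
    constexpr std::string_view XML_OPEN      = "<tool_call>";
    constexpr std::string_view XML_CLOSE     = "</tool_call>";
    constexpr std::string_view SIMPLE_PREFIX = "functions.";

    constexpr std::size_t XML_OPEN_LEN  = XML_OPEN.size();
    constexpr std::size_t XML_CLOSE_LEN = XML_CLOSE.size();
}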
* Fix duplicate common_chat_parse definition
- Remove duplicate implementation from chat-parser.cpp
- Keep single implementation in chat.cpp following llama.cpp patterns
- Resolves linker error: multiple definition of common_chat_parse
* Fix JSON assertion failure in function call parsing
- Add proper validation that 'function' field is an object before accessing nested keys
- Handle missing 'arguments' field gracefully with default "{}"
- Prevents crash when parsing malformed tool call JSON structures
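A minimal sketch of the validation described above using nlohmann::json; the helper name and call site are hypothetical:

#include <string>
#include "json.hpp"

using json = nlohmann::ordered_json;

// Guard against malformed tool-call JSON: only touch nested keys once we know
// "function" is an object, and default missing "arguments" to "{}".
static bool extract_tool_call(const json & tc, std::string & name, std::string & arguments) {
    if (!tc.contains("function") || !tc["function"].is_object()) {
        return false; // malformed entry: skip it instead of tripping an assertion
    }
    const json & fn = tc["function"];
    if (!fn.contains("name") || !fn["name"].is_string()) {
        return false;
    }
    name = fn["name"].get<std::string>();
    if (fn.contains("arguments")) {
        // arguments may arrive as a JSON object or as a pre-encoded string
        arguments = fn["arguments"].is_string() ? fn["arguments"].get<std::string>()
                                                : fn["arguments"].dump();
    } else {
        arguments = "{}";
    }
    return true;
}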
* Add comprehensive Qwen3 XML tool calling support with unit tests
- Implement Qwen3 XML parser with <tool_call>{"name": "func", "arguments": {...}}</tool_call> format
- Add model detection and routing for Qwen3 vs Kimi-K2 formats
- Create 8 comprehensive unit tests covering parsing, streaming, error handling
- Fix token format cleaning bug in kimi_k2_parser.hpp processing order
- Remove progressive parsing code and related utilities
- Add tool injection support for Qwen3 format in server utils
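A minimal sketch of extracting one <tool_call> block in the format shown above; the helper name is illustrative and error handling is simplified compared to the real parser:

#include <cstddef>
#include <string>
#include "json.hpp"

using json = nlohmann::ordered_json;

// Find the first <tool_call>...</tool_call> block and parse its JSON payload.
// Returns false when no complete block is present (e.g. partial streaming output).
static bool parse_qwen3_tool_call(const std::string & text, std::string & name, std::string & arguments) {
    static const std::string open_tag  = "<tool_call>";
    static const std::string close_tag = "</tool_call>";

    const std::size_t start = text.find(open_tag);
    if (start == std::string::npos) return false;
    const std::size_t end = text.find(close_tag, start + open_tag.size());
    if (end == std::string::npos) return false; // incomplete block: wait for more tokens

    const std::string payload = text.substr(start + open_tag.size(), end - start - open_tag.size());
    json call = json::parse(payload, nullptr, /*allow_exceptions=*/false);
    if (!call.is_object() || !call.contains("name") || !call["name"].is_string()) return false;

    name      = call["name"].get<std::string>();
    arguments = call.contains("arguments") ? call["arguments"].dump() : "{}";
    return true;
}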
* Add DeepSeek R1 function calling support with comprehensive unit tests
- Implement complete DeepSeek R1 tool call parsing in common_chat_parser.cpp
- Add DeepSeek R1 model detection and tool injection in deepseek_r1_tools.hpp
- Update function_calls.hpp with DeepSeek R1 integration and content extraction
- Update documentation to reflect support for Kimi-K2, Qwen3, and DeepSeek R1 models
- Add comprehensive unit tests for DeepSeek R1 reasoning, tool calls, and integration
- Port exact implementation patterns from original llama.cpp for compatibility
Key features:
- Native DeepSeek R1 format: <|tool▁calls▁begin|>function<|tool▁sep|>name```json{}```<|tool▁call▁end|><|tool▁calls▁end|>
- Reasoning content extraction from <think>...</think> tags
- Multiple tool calls support with separate call blocks
- Model detection for deepseek-r1, deepseek_r1 naming patterns
- Integration with incremental parsing and streaming support
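A minimal sketch of the model-name detection and <think> handling listed above; function names are illustrative rather than the actual deepseek_r1_tools.hpp API:

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <string>

// Case-insensitive check for the deepseek-r1 / deepseek_r1 naming patterns.
static bool is_deepseek_r1_model(std::string model) {
    std::transform(model.begin(), model.end(), model.begin(),
                   [](unsigned char c) { return (char) std::tolower(c); });
    return model.find("deepseek-r1") != std::string::npos ||
           model.find("deepseek_r1") != std::string::npos;
}

// Split reasoning out of <think>...</think> so it can be reported separately
// from the visible assistant content.
static std::string extract_reasoning(std::string & content) {
    static const std::string open_tag = "<think>", close_tag = "</think>";
    const std::size_t start = content.find(open_tag);
    if (start == std::string::npos) return "";
    const std::size_t end = content.find(close_tag, start + open_tag.size());
    if (end == std::string::npos) return "";
    std::string reasoning = content.substr(start + open_tag.size(), end - start - open_tag.size());
    content.erase(start, end + close_tag.size() - start);
    return reasoning;
}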
* Add partial parsing support for JSON and regex
- json-partial.h/cpp: JSON partial parsing functionality
- regex-partial.h/cpp: Regex partial parsing functionality
* Add format_chat integration tests for Qwen3 tool injection
- Add test_qwen3_format_chat_integration() to validate tool injection pipeline
- Test tool injection conditions and system message enhancement
- Verify JSON formatting and anti-preamble instructions
- Add comprehensive test documentation
Tests confirm tool injection works correctly; the conversational preamble
issue is not in ik_llama.cpp but likely in the UI configuration.
* Fix Qwen3 tool call parsing - pass model name to parser
The server was not passing the model name to parse_chat_message_incremental(),
causing Qwen3 to fall back to the Kimi-K2 parser and return tool calls
as content instead of a proper tool_calls array.
* Fix non-streaming path to use model-specific parsing
Non-streaming responses were hardcoded to use the Kimi-K2 format,
causing Qwen3 XML tool calls to be returned as content instead
of a proper tool_calls array. The non-streaming path now uses the same
model detection as the streaming path for consistency (see the sketch below).
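Both fixes amount to doing the format detection once, from the request's model name, and using it on the streaming and non-streaming paths alike. A minimal sketch of that routing, with illustrative enum and helper names:

#include <string>

enum class tool_call_format { KIMI_K2, QWEN3, DEEPSEEK_R1 };

// Single detection point shared by streaming and non-streaming responses, so
// Qwen3 / DeepSeek R1 output is no longer run through the Kimi-K2 rules.
static tool_call_format detect_tool_call_format(const std::string & model_name) {
    if (model_name.find("qwen3") != std::string::npos)    return tool_call_format::QWEN3;
    if (model_name.find("deepseek") != std::string::npos) return tool_call_format::DEEPSEEK_R1;
    return tool_call_format::KIMI_K2; // default: matches the pre-existing behaviour
}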
Committed by GitHub
parent 0451f10a42
commit 3701fb1686
@@ -20,6 +20,9 @@
 #include "json.hpp"
 #include "index.html.gz.hpp"
 #include "loading.html.hpp"
+#include "function_calls.hpp"
+#include "streaming_chat.hpp"
+#include "../../common/chat-parser.h"

 #include <atomic>
 #include <chrono>
@@ -30,6 +33,8 @@
 #include <thread>
 #include <signal.h>
 #include <memory>
 #include <random>
 #include <algorithm>
 #include <src/llama-impl.h>

 using json = nlohmann::ordered_json;
@@ -38,6 +43,7 @@ bool server_verbose = false;
 bool server_log_json = true;



 enum stop_type {
     STOP_TYPE_FULL,
     STOP_TYPE_PARTIAL,
@@ -135,6 +141,74 @@ struct server_task_result {
 std::unordered_map<int, server_task_result > server_task_result_dict = {};

+// Helper functions for content cleaning
+static std::string remove_simple_function_calls(const std::string& content) {
+    std::string cleaned = content;
+    const std::string func_pattern = "functions.";
+    size_t pos = 0;
+    while ((pos = cleaned.find(func_pattern, pos)) != std::string::npos) {
+        size_t func_start = pos;
+
+        // Find the opening brace for arguments
+        size_t brace_pos = cleaned.find('{', pos);
+        if (brace_pos == std::string::npos) {
+            pos += func_pattern.length();
+            continue;
+        }
+
+        // Find the matching closing brace
+        int brace_count = 1;
+        size_t end_pos = brace_pos + 1;
+        while (end_pos < cleaned.length() && brace_count > 0) {
+            if (cleaned[end_pos] == '{') brace_count++;
+            else if (cleaned[end_pos] == '}') brace_count--;
+            end_pos++;
+        }
+
+        if (brace_count == 0) {
+            // Remove the entire function call
+            cleaned.erase(func_start, end_pos - func_start);
+            pos = func_start;
+        } else {
+            pos += func_pattern.length();
+        }
+    }
+    return cleaned;
+}
+
+static std::string remove_xml_function_calls(const std::string& content) {
+    std::string cleaned = content;
+    size_t pos = 0;
+    while ((pos = cleaned.find("<tool_call>", pos)) != std::string::npos) {
+        size_t tool_call_start = pos;
+        size_t tool_call_end = cleaned.find("</tool_call>", tool_call_start);
+        if (tool_call_end == std::string::npos) {
+            pos = tool_call_start + 11;
+            continue;
+        }
+
+        // Remove the entire XML tool call block
+        cleaned.erase(tool_call_start, tool_call_end - tool_call_start + 12);
+        pos = tool_call_start;
+    }
+    return cleaned;
+}
+
+static std::string clean_all_function_call_formats(const std::string& content) {
+    std::string cleaned = content;
+
+    // Remove XML format first
+    cleaned = remove_xml_function_calls(cleaned);
+
+    // Then remove simple format
+    cleaned = remove_simple_function_calls(cleaned);
+
+    // Trim whitespace from cleaned content
+    cleaned.erase(0, cleaned.find_first_not_of(" \t\n\r"));
+    cleaned.erase(cleaned.find_last_not_of(" \t\n\r") + 1);
+
+    return cleaned;
+}
+
 struct server_task_multi {
     int id = -1;
@@ -191,6 +265,11 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;

+    // Streaming tool call state
+    ik_chat_msg previous_msg;
+    ik_chat_msg current_msg;
+    std::vector<std::string> tool_call_ids;
+
     bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
@@ -242,6 +321,37 @@
         n_past_se = 0;

         generated_token_probs.clear();
+
+        // Reset streaming tool call state
+        previous_msg = ik_chat_msg();
+        current_msg = ik_chat_msg();
+        tool_call_ids.clear();
     }
+
+    // Update chat message and compute diffs for streaming tool calls
+    // Based on original llama.cpp update_chat_msg pattern
+    const ik_chat_msg & update_chat_msg(std::vector<ik_chat_msg_diff> & diffs) {
+        ik_chat_msg previous = current_msg;
+
+        try {
+            // Parse generated text incrementally (is_partial = true during generation)
+            bool is_partial = !stopped_eos && !stopped_word && !stopped_limit;
+            ik_chat_msg new_msg = parse_chat_message_incremental(generated_text, is_partial, oaicompat_model);
+
+            if (!new_msg.empty()) {
+                // Ensure tool call IDs are set consistently across streaming chunks
+                new_msg.ensure_tool_call_ids_set(tool_call_ids, generate_tool_call_id);
+                current_msg = new_msg;
+
+                // Compute diffs for streaming
+                diffs = ik_chat_msg_diff::compute_diffs(previous, current_msg);
+            }
+        } catch (const std::exception& e) {
+            // If parsing fails, don't update current_msg and return empty diffs
+            diffs.clear();
+        }
+
+        return current_msg;
+    }

     bool has_budget(gpt_params &global_params) {
@@ -1499,13 +1609,43 @@ struct server_context {
        res.id_multi = slot.id_multi;
        res.error = false;
        res.stop = false;
+
+        // Update chat message and compute diffs for streaming tool calls
+        // Following original llama.cpp pattern (server.cpp:2503)
+        std::vector<ik_chat_msg_diff> oaicompat_msg_diffs;
+        slot.update_chat_msg(oaicompat_msg_diffs);
+
+        // Following original llama.cpp pattern: send empty content in streaming mode
+        // Clean content comes through oaicompat_msg_diffs instead of raw tokens
        res.data = json {
-            {"content", tkn.text_to_send},
+            {"content", ""}, // Empty - clean content provided via diffs
            {"stop", false},
            {"id_slot", slot.id},
            {"multimodal", false}
        };
+
+        // Store diffs for format_partial_response_oaicompat to use
+        // Convert ik_chat_msg_diff to JSON format for storage
+        json diffs_json = json::array();
+        for (const auto & diff : oaicompat_msg_diffs) {
+            json diff_obj;
+            if (!diff.content_delta.empty()) {
+                diff_obj["content_delta"] = diff.content_delta;
+            }
+            if (diff.tool_call_index != std::string::npos) {
+                diff_obj["tool_call_index"] = diff.tool_call_index;
+                diff_obj["tool_call_delta"] = {
+                    {"id", diff.tool_call_delta.id},
+                    {"name", diff.tool_call_delta.name},
+                    {"arguments", diff.tool_call_delta.arguments}
+                };
+            }
+            if (!diff_obj.empty()) {
+                diffs_json.push_back(diff_obj);
+            }
+        }
+        res.data["oaicompat_msg_diffs"] = diffs_json;

        if (slot.sparams.n_probs > 0) {
            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
            const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -2587,19 +2727,57 @@ static json format_final_response_oaicompat(const json& request, json result, co
    int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
    std::string content = json_value(result, "content", std::string(""));

+    // Parse tool calls using model-specific format detection
+    std::string model_name = json_value(request, "model", std::string(""));
+
+    // Use the same parsing logic as streaming path for consistency
+    ik_chat_msg parsed_msg = parse_chat_message_incremental(content, false, model_name);
+
+    // Convert to JSON format for compatibility
+    json tool_calls = json::array();
+    for (const auto & tc : parsed_msg.tool_calls) {
+        tool_calls.push_back({
+            {"type", "function"},
+            {"function", {
+                {"name", tc.name},
+                {"arguments", tc.arguments}
+            }},
+            {"id", tc.id}
+        });
+    }
+
+    bool has_tool_calls = !tool_calls.empty();
+
+    // Use cleaned content from parser (following original llama.cpp pattern)
+    if (has_tool_calls) {
+        content = parsed_msg.content; // Parser already cleaned the content
+    }
+
    std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
+    if (has_tool_calls) {
+        finish_reason = "tool_calls";
+    } else if (stopped_word || stopped_eos) {
        finish_reason = "stop";
    }

+    json message = json{{"role", "assistant"}};
+    // Follow EXACT original llama.cpp pattern: content is null only when content is empty AND tool calls exist
+    if (content.empty() && has_tool_calls) {
+        message["content"] = nullptr; // Original: json() when content empty AND tool calls exist
+    } else {
+        message["content"] = content.empty() ? nullptr : content; // Original: use actual content otherwise
+    }
+    if (has_tool_calls) {
+        message["tool_calls"] = tool_calls;
+    }

    json choices =
        streaming ? json::array({ json{{"finish_reason", finish_reason},
                                        {"index", 0},
                                        {"delta", json::object()}} })
                  : json::array({ json{{"finish_reason", finish_reason},
                                        {"index", 0},
-                                        {"message", json{{"content", content},
-                                                         {"role", "assistant"}}}} });
+                                        {"message", message}} });

    std::time_t t = std::time(0);
@@ -2653,6 +2831,59 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta

    std::time_t t = std::time(0);

+    // Follow original llama.cpp pattern: Always process diffs and add final chunk
+    std::vector<json> streaming_chunks;
+
+    // Extract diffs from task result (populated by send_partial_response)
+    // Following original llama.cpp pattern where diffs are stored in task result
+    std::vector<ik_chat_msg_diff> diffs;
+
+    if (result.contains("oaicompat_msg_diffs") && result["oaicompat_msg_diffs"].is_array()) {
+        for (const auto & diff_json : result["oaicompat_msg_diffs"]) {
+            ik_chat_msg_diff diff;
+
+            // Extract content delta
+            diff.content_delta = diff_json.value("content_delta", "");
+
+            // Extract tool call data
+            if (diff_json.contains("tool_call_index")) {
+                diff.tool_call_index = diff_json["tool_call_index"];
+                if (diff_json.contains("tool_call_delta")) {
+                    const auto & tc_delta = diff_json["tool_call_delta"];
+                    diff.tool_call_delta.id = tc_delta.value("id", "");
+                    diff.tool_call_delta.name = tc_delta.value("name", "");
+                    diff.tool_call_delta.arguments = tc_delta.value("arguments", "");
+                }
+            } else {
+                diff.tool_call_index = std::string::npos;
+            }
+
+            diffs.push_back(diff);
+        }
+    }
+
+    streaming_chunks = generate_streaming_chunks(diffs, completion_id, modelname);
+
+    // Always add final chunk (like original llama.cpp)
+    if (!finish_reason.empty()) {
+        json finish_chunk = {
+            {"choices", json::array({json{{"finish_reason", finish_reason},
+                                          {"index", 0},
+                                          {"delta", json::object()}}})},
+            {"created", t},
+            {"id", completion_id},
+            {"model", modelname},
+            {"object", "chat.completion.chunk"}
+        };
+        streaming_chunks.push_back(finish_chunk);
+    }
+
+    // Return streaming chunks (could be just final chunk if no diffs)
+    if (!streaming_chunks.empty()) {
+        return streaming_chunks;
+    }
+
+    // Fallback to original streaming logic for non-tool calls
    json choices;

    if (!finish_reason.empty()) {
@@ -2812,6 +3043,7 @@ int main(int argc, char ** argv) {
    // TODO: not great to use extern vars
    server_log_json = params.log_json;
    server_verbose = params.verbosity > 0;


    // struct that contains llama context and inference
    server_context ctx_server;