From 24ac2596ef0ccf5b8189a5f3fa4d76e80a1f8cea Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 9 Aug 2025 11:10:19 +0300 Subject: [PATCH] gmp-oss: common --- common/chat-parser.h | 36 ++++++++++----------- common/chat.cpp | 75 +++++++++++++++++++++++++++++--------------- common/chat.h | 25 ++++++++------- 3 files changed, 82 insertions(+), 54 deletions(-) diff --git a/common/chat-parser.h b/common/chat-parser.h index 7c660e53..1e7a3f94 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -24,9 +24,9 @@ class common_chat_msg_parser { std::string prelude; std::vector groups; }; - + common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax); - + // Accessors const std::string & input() const { return input_; } size_t pos() const { return pos_; } @@ -42,7 +42,7 @@ class common_chat_msg_parser { } pos_ = pos; } - + void move_back(size_t n) { if (pos_ < n) { throw std::runtime_error("Can't move back that far!"); @@ -56,46 +56,46 @@ class common_chat_msg_parser { // Content manipulation void add_content(const std::string & content); void add_reasoning_content(const std::string & reasoning_content); - + // Tool call manipulation void add_tool_call(const common_chat_tool_call & tool_call); bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); bool add_tool_call(const json & tool_call); bool add_tool_calls(const json & arr); void clear_tools(); - + // Parsing utilities std::string consume_rest(); bool try_consume_literal(const std::string & literal); void consume_literal(const std::string & literal); bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); - + // Regex-based parsing methods (new) std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); find_regex_result consume_regex(const common_regex & regex); std::optional try_consume_regex(const common_regex & regex); - + // Progressive parsing primitives (for Phase 4) std::optional try_find_literal(const std::string & literal); bool consume_spaces(); void set_healing_marker(const std::string & marker); - - + + // Main parsing entry point void parse(); - + // Finishing void finish(); - + // Result extraction common_chat_msg result_and_reset(); - + // Advanced JSON parsing (following original llama.cpp patterns) struct consume_json_result { json value; bool is_partial; }; - + std::optional try_consume_json(); common_json consume_json(); consume_json_result consume_json_with_dumped_args( @@ -112,8 +112,8 @@ private: void parse_kimi_k2_format(); void parse_deepseek_r1_format(); void parse_generic_format(); - - + + // JSON parsing utilities (enhanced streaming support) struct json_parse_result { json value; @@ -121,11 +121,11 @@ private: bool is_partial; std::string healing_marker; }; - + // Partial detection utilities bool detect_partial_function_call(const std::string& content); void handle_partial_detection(); - + // Legacy find_literal for compatibility std::optional try_find_literal_legacy(const std::string & literal); }; @@ -133,4 +133,4 @@ private: // Main parsing function (public API) common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); -// Content-only parsing for fallback scenarios (static internal function) \ No newline at end of file +// Content-only parsing for fallback scenarios (static internal function) diff --git a/common/chat.cpp b/common/chat.cpp index f62c2801..8be08633 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -220,7 +220,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Check for the new tools array format first (no DeepSeek markers) auto original_pos = builder.pos(); - + // First, try the tools array format for content like "function\n```json\n{"tools": [...]}" if (builder.try_find_regex(function_regex_simple)) { builder.move_to(original_pos); @@ -231,7 +231,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Fall through to try standard DeepSeek patterns } } - + // If tools array format didn't work, try XML-wrapped format builder.move_to(original_pos); try { @@ -240,7 +240,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { } catch (const common_chat_msg_partial_exception&) { // Fall through to try standard DeepSeek patterns } - + // If XML wrapper format didn't work, try standard DeepSeek patterns builder.move_to(original_pos); try { @@ -278,7 +278,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { throw; // Re-throw for partial mode } } - + // Add any remaining content (critical for responses without tool calls) builder.add_content(builder.consume_rest()); } @@ -286,19 +286,19 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { static const common_regex prefix("function\n```json\n"); - - + + if (auto res = builder.try_find_regex(prefix)) { // Parse JSON and manually process tools array to convert arguments to strings auto json_result = builder.try_consume_json(); if (!json_result) { throw common_chat_msg_partial_exception("invalid JSON"); } - - + + // DeepSeek R1 format has "tools" array, manually process each tool if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) { - + // Manually create tool calls array with string arguments (following original pattern) json tools_with_dumped_args = json::array(); for (const auto& tool : json_result->json.at("tools")) { @@ -310,15 +310,15 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { tools_with_dumped_args.push_back(formatted_tool); } } - - + + if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) { throw common_chat_msg_partial_exception("incomplete tool call array"); } } else { throw common_chat_msg_partial_exception("tools key not found or not array"); } - + // Consume closing ``` builder.try_consume_regex(common_regex("```")); } else { @@ -326,41 +326,41 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { } } -// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern +// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) { - + // Pattern for: \nfunctionFunctionName\n```json\n{...}\n```\n static const common_regex xml_pattern( "\\s*" // Opening XML tag - "function([^\\n]+)" // Function name after "function" + "function([^\\n]+)" // Function name after "function" "\\s*```json\\s*" // JSON block start ); - + if (auto res = builder.try_find_regex(xml_pattern)) { - + // Extract function name from capture group std::string function_name = builder.str(res->groups[1]); - + // Parse JSON arguments auto json_result = builder.try_consume_json(); if (!json_result) { throw common_chat_msg_partial_exception("invalid JSON in XML wrapper"); } - - + + // Create single tool call following original pattern json tool_call; tool_call["name"] = function_name; tool_call["arguments"] = json_result->json.dump(); // Convert to string - + json tool_calls_array = json::array(); tool_calls_array.push_back(tool_call); - - + + if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) { throw common_chat_msg_partial_exception("incomplete XML wrapped tool call"); } - + // Consume closing ```\n builder.try_consume_regex(common_regex("```\\s*")); } else { @@ -384,6 +384,15 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { builder.add_content(kimi_k2::clean_content(builder.input())); } +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + // TODO @ngxson : this won't work with --special enabled, we should fix that + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>"); + if (!builder.syntax().enable_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } +} + // Main parsing dispatch function static void common_chat_parse(common_chat_msg_parser & builder) { switch (builder.syntax().format) { @@ -399,6 +408,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_KIMI_K2: common_chat_parse_kimi_k2(builder); break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } @@ -432,6 +444,19 @@ const char* common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_GENERIC: return "generic"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1"; case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2"; + case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; default: return "unknown"; } -} \ No newline at end of file +} + +const char * common_reasoning_format_name(common_reasoning_format format) { + switch (format) { + case COMMON_REASONING_FORMAT_NONE: return "none"; + case COMMON_REASONING_FORMAT_AUTO: return "auto"; + case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + default: + throw std::runtime_error("Unknown reasoning format"); + } +} + diff --git a/common/chat.h b/common/chat.h index 5899ef1a..83e31566 100644 --- a/common/chat.h +++ b/common/chat.h @@ -13,20 +13,20 @@ struct common_chat_templates; struct common_string_range { size_t begin; size_t end; - + common_string_range(size_t begin, size_t end) : begin(begin), end(end) { if (begin > end) { throw std::runtime_error("Invalid range"); } } - + // prevent default ctor common_string_range() = delete; - + bool empty() const { return begin == end; } - + bool operator==(const common_string_range & other) const { return begin == other.begin && end == other.end; } @@ -40,7 +40,7 @@ struct common_chat_tool_call { bool operator==(const common_chat_tool_call & other) const { return name == other.name && arguments == other.arguments && id == other.id; } - + bool operator!=(const common_chat_tool_call & other) const { return !(*this == other); } @@ -65,10 +65,10 @@ struct common_chat_msg { std::string tool_call_id; bool empty() const { - return content.empty() && content_parts.empty() && tool_calls.empty() && + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); } - + void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { @@ -91,7 +91,7 @@ struct common_chat_msg { && tool_name == other.tool_name && tool_call_id == other.tool_call_id; } - + bool operator!=(const common_chat_msg & other) const { return !(*this == other); } @@ -110,7 +110,7 @@ struct common_chat_msg_diff { && tool_call_index == other.tool_call_index && tool_call_delta == other.tool_call_delta; } - + bool operator!=(const common_chat_msg_diff & other) const { return !(*this == other); } @@ -132,18 +132,20 @@ enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, COMMON_CHAT_FORMAT_GENERIC, COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility) }; enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_AUTO, COMMON_REASONING_FORMAT_DEEPSEEK, COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, }; struct common_chat_syntax { common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE; // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; bool thinking_forced_open = false; @@ -165,11 +167,12 @@ class common_chat_msg_partial_exception : public std::runtime_error { // Format detection from chat template common_chat_format common_chat_format_detect(const std::string & chat_template); const char* common_chat_format_name(common_chat_format format); +const char* common_reasoning_format_name(common_reasoning_format format); // Main parsing function (entry point for original llama.cpp compatibility) common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); -// Forward declare parser class +// Forward declare parser class class common_chat_msg_parser; // Format-specific parsing functions (accessible from chat-parser)