Enable CUDA graphs for MoE models + GPT-OSS support (#689)

* gpt-oss: common

* gpt-oss: attention sinks, swiglu_oai

* gpt-oss: WIP llama

The model loads and runs (CPU only), but PPL is much too high
(~1500 for the 1st batch vs ~200 in mainline).
Is it because of SWA, because of vocab, or did I introduce a bug somewhere?

* gpt-oss: CPU seems to be working

It was the SWA that was missing in the previous commit.

There are issues with EOG tokens, so EOG handling still needs to be added.

* CUDA: ADD_ID

Just a copy from mainline
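For context, a minimal CPU sketch of what ADD_ID computes as I understand it
(names are illustrative, this is not the CUDA kernel being copied): for each
input row, add the bias row selected by that row's expert id.

    #include <cstddef>
    #include <cstdint>

    // ADD_ID: dst[i] = src[i] + bias[ids[i]] for each row i.
    // Used in MoE layers to add the per-expert bias to rows after routing.
    // src: n_rows x n_cols, bias: n_experts x n_cols, ids: n_rows entries.
    static void add_id_ref(float * dst, const float * src, const float * bias,
                           const int32_t * ids, int n_rows, int n_cols) {
        for (int i = 0; i < n_rows; ++i) {
            const float * b = bias + (size_t)ids[i] * n_cols;
            for (int j = 0; j < n_cols; ++j) {
                dst[(size_t)i * n_cols + j] = src[(size_t)i * n_cols + j] + b[j];
            }
        }
    }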

* gpt-oss: Seems to be working on CUDA

* gpt-oss: add sinks to the attn-vec kernels
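For reference, GPT-OSS attention sinks are per-head learned logits that enter
the softmax denominator without contributing a value, so they dampen a row's
attention weights when no key matches the query well. A scalar sketch of one
attention row (illustrative, not the actual vec kernel):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Weighted sum over n_kv values for one query row, with a sink logit.
    // The sink adds exp(sink - max) to the denominator but no value term.
    static void attn_row_with_sink(float * out, const float * logits,
                                   const float * values, // n_kv x head_dim
                                   int n_kv, int head_dim, float sink) {
        float max_l = sink;
        for (int i = 0; i < n_kv; ++i) max_l = std::max(max_l, logits[i]);
        float denom = std::exp(sink - max_l); // the sink's share of the mass
        std::vector<float> w(n_kv);
        for (int i = 0; i < n_kv; ++i) {
            w[i] = std::exp(logits[i] - max_l);
            denom += w[i];
        }
        std::fill(out, out + head_dim, 0.0f);
        for (int i = 0; i < n_kv; ++i) {
            const float wi = w[i] / denom;
            for (int d = 0; d < head_dim; ++d) out[d] += wi * values[i*head_dim + d];
        }
    }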

* CUDA: add head size of 64 to new mma

Haven't turned it on yet, but I observe slightly better PP and slightly
worse TG performance with it.

* gpt-oss: add ability to use -fmoe (only CUDA for now)

* Move row sums to the right place

* Add sinks to iqk flash attention

* gpt_oss: Implement -fmoe on the CPU

* Simdify swiglu_oai

Turning it off for now as performance becomes more variable;
perhaps I'm running into thermal throttling more often
because I'm making the CPU work too hard.
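For reference, the scalar form being SIMD-ified, assuming the GPT-OSS
convention of alpha = 1.702 and a clamping limit of 7.0 (values from the
reference implementation; adjust if they differ here):

    #include <algorithm>
    #include <cmath>

    // swiglu_oai: the gated SiLU variant used by GPT-OSS. The gate input is
    // clamped from above, the linear input symmetrically, and the linear
    // branch gets a +1 bias before gating.
    static inline float swiglu_oai(float x_gate, float x_up,
                                   float alpha = 1.702f, float limit = 7.0f) {
        x_gate = std::min(x_gate, limit);
        x_up   = std::clamp(x_up, -limit, limit);
        const float gate = x_gate / (1.0f + std::exp(-alpha * x_gate));
        return gate * (x_up + 1.0f);
    }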

* llama: factor out model loader

* Builds successfully

* It runs, but mmap does not work

* Fix llama_mmap so mmap works

* Minor

* Fix CUDA after latest changes

* Attempt to use CUDA graphs with MoE models - not working

* CUDA graphs WIP - still not working

* CUDA graphs - seems to be working

Likely not all MLA variants are working.
I no longer remember why I added the q8_0 cpy that
transposes the tensor, but if it is really needed, it is
now missing. Also missing is q6_0.
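For context, the stock CUDA graph pattern is capture once, replay on
subsequent evaluations; a minimal sketch with the plain runtime API (not the
actual integration here, which also has to detect when the graph must be
re-captured, e.g. when MoE routing changes):

    #include <cuda_runtime.h>

    // Record the kernel launches enqueued on `stream` into a graph, then
    // instantiate it. Replaying the executable graph skips per-kernel launch
    // overhead, which is what makes graphs pay off for token generation.
    static cudaGraphExec_t capture_graph(cudaStream_t stream,
                                         void (*enqueue_work)(cudaStream_t)) {
        cudaGraph_t graph = nullptr;
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
        enqueue_work(stream); // launches are recorded, not executed
        cudaStreamEndCapture(stream, &graph);
        cudaGraphExec_t exec = nullptr;
        cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
        cudaGraphDestroy(graph);
        return exec;
    }

    // Subsequent evaluations: cudaGraphLaunch(exec, stream);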

* Make q8_0 cache work for DeepSeek models with CUDA graphs

* cuda: cpy for q6_0

* Fix llama_mmap on non-Linux platforms

* Adding forgotten file

* Iterating on Windows build failures

* cuda: re-add q8_0 -> q8_0 transpose

so mla = 2 can be used with CUDA graphs and q8_0 cache.
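For context, q8_0 stores blocks of 32 int8 quants sharing one scale along the
contiguous dimension, so a transposed copy cannot just permute bytes: it has
to dequantize and requantize along the new contiguous dimension. A CPU
reference sketch (illustrative; real ggml uses an fp16 scale and the CUDA
kernel is organized differently):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    #define QK8_0 32
    struct block_q8_0 { float d; int8_t qs[QK8_0]; }; // ggml uses fp16 for d

    // Transpose a rows x cols q8_0 matrix into cols x rows. Blocks run along
    // the contiguous dimension, so rows and cols must be multiples of 32.
    static void transpose_q8_0(block_q8_0 * dst, const block_q8_0 * src,
                               int rows, int cols) {
        for (int j = 0; j < cols; ++j) {      // output row j = input column j
            for (int ib = 0; ib < rows/QK8_0; ++ib) {
                float v[QK8_0], amax = 0.0f;
                for (int k = 0; k < QK8_0; ++k) {
                    const block_q8_0 & s =
                        src[(size_t)(ib*QK8_0 + k)*(cols/QK8_0) + j/QK8_0];
                    v[k] = s.d * s.qs[j % QK8_0]; // dequantize source element
                    amax = std::max(amax, std::fabs(v[k]));
                }
                const float d  = amax / 127.0f;  // new scale for the block
                const float id = d > 0.0f ? 1.0f/d : 0.0f;
                block_q8_0 & b = dst[(size_t)j*(rows/QK8_0) + ib];
                b.d = d;
                for (int k = 0; k < QK8_0; ++k)
                    b.qs[k] = (int8_t)std::lround(v[k]*id);
            }
        }
    }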

* Disable graphs without -fmoe

* Minor

* Turn graphs on by default

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Kawrakow, 2025-08-15 09:18:07 +03:00 (committed by GitHub)
commit fc06bc9d27, parent e082df47f2
56 changed files with 8720 additions and 5115 deletions


@@ -220,7 +220,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
// Check for the new tools array format first (no DeepSeek markers)
auto original_pos = builder.pos();
// First, try the tools array format for content like "function\n```json\n{"tools": [...]}"
if (builder.try_find_regex(function_regex_simple)) {
builder.move_to(original_pos);
@@ -231,7 +231,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
// Fall through to try standard DeepSeek patterns
}
}
// If tools array format didn't work, try XML-wrapped format
builder.move_to(original_pos);
try {
@@ -240,7 +240,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
} catch (const common_chat_msg_partial_exception&) {
// Fall through to try standard DeepSeek patterns
}
// If XML wrapper format didn't work, try standard DeepSeek patterns
builder.move_to(original_pos);
try {
@@ -278,7 +278,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
throw; // Re-throw for partial mode
}
}
// Add any remaining content (critical for responses without tool calls)
builder.add_content(builder.consume_rest());
}
@@ -286,19 +286,19 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
// Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern
static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
static const common_regex prefix("function\n```json\n");
if (auto res = builder.try_find_regex(prefix)) {
// Parse JSON and manually process tools array to convert arguments to strings
auto json_result = builder.try_consume_json();
if (!json_result) {
throw common_chat_msg_partial_exception("invalid JSON");
}
// DeepSeek R1 format has "tools" array, manually process each tool
if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) {
// Manually create tool calls array with string arguments (following original pattern)
json tools_with_dumped_args = json::array();
for (const auto& tool : json_result->json.at("tools")) {
@@ -310,15 +310,15 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
tools_with_dumped_args.push_back(formatted_tool);
}
}
if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
throw common_chat_msg_partial_exception("tools key not found or not array");
}
// Consume closing ```
builder.try_consume_regex(common_regex("```"));
} else {
@@ -326,41 +326,41 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) {
}
}
// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern
static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) {
// Pattern for: <tool_call>\nfunction</think>FunctionName\n```json\n{...}\n```\n</tool_call>
static const common_regex xml_pattern(
"<tool_call>\\s*" // Opening XML tag
"function</think>([^\\n]+)" // Function name after "function</think>"
"function</think>([^\\n]+)" // Function name after "function</think>"
"\\s*```json\\s*" // JSON block start
);
if (auto res = builder.try_find_regex(xml_pattern)) {
// Extract function name from capture group
std::string function_name = builder.str(res->groups[1]);
// Parse JSON arguments
auto json_result = builder.try_consume_json();
if (!json_result) {
throw common_chat_msg_partial_exception("invalid JSON in XML wrapper");
}
// Create single tool call following original pattern
json tool_call;
tool_call["name"] = function_name;
tool_call["arguments"] = json_result->json.dump(); // Convert to string
json tool_calls_array = json::array();
tool_calls_array.push_back(tool_call);
if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) {
throw common_chat_msg_partial_exception("incomplete XML wrapped tool call");
}
// Consume closing ```\n</tool_call>
builder.try_consume_regex(common_regex("```\\s*</tool_call>"));
} else {
@@ -384,6 +384,15 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
builder.add_content(kimi_k2::clean_content(builder.input()));
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
// TODO @ngxson : this won't work with --special enabled, we should fix that
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
if (!builder.syntax().enable_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
}
// Main parsing dispatch function
static void common_chat_parse(common_chat_msg_parser & builder) {
switch (builder.syntax().format) {
@@ -399,6 +408,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
@@ -432,6 +444,19 @@ const char* common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_GENERIC: return "generic";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1";
case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default: return "unknown";
}
}
const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
default:
throw std::runtime_error("Unknown reasoning format");
}
}