Fix split mode graph with Qwen3.5-MoE/Qwen3-Next hybrid inference (#1368)

This commit is contained in:
Kawrakow
2026-03-06 07:26:15 +01:00
committed by GitHub
parent 3208660d20
commit fa0c29843d
2 changed files with 8 additions and 12 deletions

View File

@@ -1304,7 +1304,7 @@ llm_expert_gating_func_type gating_op,
if (up_shexp && gate_shexp && down_shexp) {
if (split_up_shexp) {
std::vector<ggml_tensor *> results; results.reserve(split_up_shexp->n_device);
std::vector<ggml_tensor *> results(split_up_shexp->n_device, nullptr);
GGML_ASSERT(!split_up_b_shexp || split_up_b_shexp->n_device == split_up_shexp->n_device);
GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_shexp->n_device);
GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_shexp->n_device);
@@ -1335,8 +1335,7 @@ llm_expert_gating_func_type gating_op,
split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr,
split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr,
split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false,
id == id_add_routed ? routed_out : nullptr);
nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false, nullptr);
cb(shared_out, "ffn_shexp_out", il_cb);
if (shexp_gate) {
auto split_shexp_gate = (ggml_split_tensor_t *)shexp_gate->extra;
@@ -1350,11 +1349,16 @@ llm_expert_gating_func_type gating_op,
}
cb(shared_out, "ffn_shexp_gated", il_cb);
}
if (id == id_add_routed) {
shared_out = ggml_add(ctx, shared_out, routed_out);
cb(shared_out, "ffn_shared_routed_added", il);
}
if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type);
}
ggml_build_forward_expand(graph, shared_out);
down_bias_added = true;
results.push_back(shared_out);
results[id] = shared_out;
}
GGML_ASSERT(!results.empty());
cur = ggml_reduce(ctx, results.data(), split_up_shexp->n_device, GGML_OP_ADD);

View File

@@ -1981,14 +1981,6 @@ static bool llm_load_tensors(
LLAMA_LOG_WARN("================================================================\n\n");
max_gpu = 4;
}
else if (llama_model_has_recurrent(&model) && model.has_tensor_overrides()) {
LLAMA_LOG_WARN("\n================================================================\n");
LLAMA_LOG_WARN("Split mode 'graph' for recurrent/hybrid models is currently\n");
LLAMA_LOG_WARN("disabled when using tensor overrides\n");
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
LLAMA_LOG_WARN("=======================================================\n\n");
split_mode = LLAMA_SPLIT_MODE_LAYER;
}
}
}