diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e4e5d456..cc39dca0 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1304,7 +1304,7 @@ llm_expert_gating_func_type gating_op, if (up_shexp && gate_shexp && down_shexp) { if (split_up_shexp) { - std::vector results; results.reserve(split_up_shexp->n_device); + std::vector results(split_up_shexp->n_device, nullptr); GGML_ASSERT(!split_up_b_shexp || split_up_b_shexp->n_device == split_up_shexp->n_device); GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_shexp->n_device); GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_shexp->n_device); @@ -1335,8 +1335,7 @@ llm_expert_gating_func_type gating_op, split_up_shexp->splits[id], split_up_b_shexp ? split_up_b_shexp->splits[id] : nullptr, nullptr, split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr, split_down_shexp->splits[id], !down_bias_added && split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr, - nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false, - id == id_add_routed ? routed_out : nullptr); + nullptr, type_op_shexp, LLM_FFN_PAR, cb, il, graph, false, false, nullptr); cb(shared_out, "ffn_shexp_out", il_cb); if (shexp_gate) { auto split_shexp_gate = (ggml_split_tensor_t *)shexp_gate->extra; @@ -1350,11 +1349,16 @@ llm_expert_gating_func_type gating_op, } cb(shared_out, "ffn_shexp_gated", il_cb); } + if (id == id_add_routed) { + shared_out = ggml_add(ctx, shared_out, routed_out); + cb(shared_out, "ffn_shared_routed_added", il); + } if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) { shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type); } + ggml_build_forward_expand(graph, shared_out); down_bias_added = true; - results.push_back(shared_out); + results[id] = shared_out; } GGML_ASSERT(!results.empty()); cur = ggml_reduce(ctx, results.data(), split_up_shexp->n_device, GGML_OP_ADD); diff --git a/src/llama.cpp b/src/llama.cpp index f3daeeed..825800bb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1981,14 +1981,6 @@ static bool llm_load_tensors( LLAMA_LOG_WARN("================================================================\n\n"); max_gpu = 4; } - else if (llama_model_has_recurrent(&model) && model.has_tensor_overrides()) { - LLAMA_LOG_WARN("\n================================================================\n"); - LLAMA_LOG_WARN("Split mode 'graph' for recurrent/hybrid models is currently\n"); - LLAMA_LOG_WARN("disabled when using tensor overrides\n"); - LLAMA_LOG_WARN(" => changing split mode to 'layer'\n"); - LLAMA_LOG_WARN("=======================================================\n\n"); - split_mode = LLAMA_SPLIT_MODE_LAYER; - } } }