mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-21 06:59:21 +00:00
Fix hybrid graph parallel + muge (#1426)
This commit is contained in:
@@ -4666,7 +4666,9 @@ static void llama_repack_up_gate_exps(llama_context & lctx) {
|
||||
auto & model = lctx.model;
|
||||
bool needs_repack = false;
|
||||
for (auto & l : model.layers) {
|
||||
if (l.ffn_up_gate_exps && l.ffn_up_exps && l.ffn_gate_exps) {
|
||||
if (l.ffn_up_gate_exps && l.ffn_up_exps && l.ffn_gate_exps &&
|
||||
ggml_backend_buffer_is_host(l.ffn_up_gate_exps->buffer) &&
|
||||
ggml_backend_buffer_is_host(l.ffn_up_exps->buffer) && ggml_backend_buffer_is_host(l.ffn_gate_exps->buffer)) {
|
||||
needs_repack = true; break;
|
||||
}
|
||||
}
|
||||
@@ -4675,7 +4677,9 @@ static void llama_repack_up_gate_exps(llama_context & lctx) {
|
||||
std::vector<char> aux_buffer_up, aux_buffer_gate, aux_buffer_up_gate;
|
||||
for (int il = 0; il < int(model.layers.size()); ++il) {
|
||||
auto & l = model.layers[il];
|
||||
if (l.ffn_up_gate_exps && l.ffn_up_exps && l.ffn_gate_exps) {
|
||||
if (l.ffn_up_gate_exps && l.ffn_up_exps && l.ffn_gate_exps &&
|
||||
ggml_backend_buffer_is_host(l.ffn_up_gate_exps->buffer) &&
|
||||
ggml_backend_buffer_is_host(l.ffn_up_exps->buffer) && ggml_backend_buffer_is_host(l.ffn_gate_exps->buffer)) {
|
||||
GGML_ASSERT(l.ffn_up_gate_exps->type == l.ffn_up_exps->type && l.ffn_up_gate_exps->type == l.ffn_gate_exps->type);
|
||||
GGML_ASSERT(l.ffn_up_gate_exps->ne[0] == l.ffn_up_exps->ne[0] && l.ffn_up_gate_exps->ne[0] == l.ffn_gate_exps->ne[0]);
|
||||
GGML_ASSERT(l.ffn_up_gate_exps->ne[2] == l.ffn_up_exps->ne[2] && l.ffn_up_gate_exps->ne[2] == l.ffn_gate_exps->ne[2]);
|
||||
@@ -5209,9 +5213,7 @@ struct llama_context * llama_init_from_model(
|
||||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
|
||||
}
|
||||
|
||||
if (ctx->model.split_mode != LLAMA_SPLIT_MODE_GRAPH) {
|
||||
llama_repack_up_gate_exps(*ctx);
|
||||
}
|
||||
llama_repack_up_gate_exps(*ctx);
|
||||
|
||||
// build worst-case graph
|
||||
int n_past = cparams.n_ctx - n_tokens;
|
||||
|
||||
Reference in New Issue
Block a user