mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-26 09:29:27 +00:00
Graph parallel: the next generation (#1080)
* WIP: absorb adding input into std_attn and std_ffn * WIP: NCCL infra * WIP: add reduce and fake_cpy ops * WIP * WIP: graph appears to work, layer is broken * WIP: Qwen3-MoE works with graph, layer still broken * WIP: GLM-4.5 graph works * WIP: fix sm layer (dense) * WIP: fix sm layer (MoE) * WIP: fast PP with bespoke 4-GPU NCCL I guess, I'm not using NCCL the right way as PP is very low with a single communicator group for 3 or more GPUs. But if I create 4 communicator groups for pairs of GPUs (0,1, 2,3, 0,2, 1,3) and use that, PP is fast: I'm hitting 1500 t/s for L3-70B on the 4x3090 system, which is ~20% better than the previous sm graph without NCCL. But that cannot be the solution (I cannot be creating pairwise communicators and associated logic for every possible number of GPUs). * WIP: Cohere2 * Explicitely set device * Bespoke 3-GPU case * WIP * Do not repeat get_rows multiple times * Fix 3 GPUs * OK, let's leave it in * Implement the reduce op without NCCL available * Be able to build without NCCL cmake -DGGML_NCCL=OFF disables it * Make --max-gpu work again * Slightly better for 4 GPUs without NCCL * Cleanup --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -620,6 +620,20 @@ ggml_tensor * llm_build_context::llm_build_norm(
|
||||
return cur;
|
||||
}
|
||||
|
||||
static ggml_tensor * get_input_tensor_sm_graph(ggml_tensor * input, int id) {
|
||||
auto cur = input;
|
||||
if (input->op == GGML_OP_REDUCE) {
|
||||
auto view_src = input->view_src;
|
||||
GGML_ASSERT(view_src);
|
||||
cur = input->src[id];
|
||||
if (cur == view_src || !cur) {
|
||||
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
|
||||
cur = input;
|
||||
}
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
ggml_context * ctx,
|
||||
llama_context & lctx,
|
||||
@@ -637,19 +651,21 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
ggml_tensor * act_scales,
|
||||
llm_ffn_op_type type_op,
|
||||
llm_ffn_gate_type type_gate,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph) {
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input,
|
||||
bool is_norm, ggml_tensor * add_extra) {
|
||||
|
||||
if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s &&
|
||||
up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR &&
|
||||
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
|
||||
//printf("%s: %s\n", __func__, ggml_op_name(input->op));
|
||||
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
|
||||
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
|
||||
auto u = (ggml_split_tensor_t *)up->extra;
|
||||
auto g = (ggml_split_tensor_t *)gate->extra;
|
||||
auto d = (ggml_split_tensor_t *)down->extra;
|
||||
GGML_ASSERT(u->n_device == g->n_device && u->n_device == d->n_device);
|
||||
std::vector<ggml_tensor *> ffn;
|
||||
ffn.reserve(u->n_device);
|
||||
std::vector<ggml_tensor *> ffn(u->n_device, nullptr);
|
||||
int id_last = -1;
|
||||
for (int id = 0; id < u->n_device; ++id) {
|
||||
int il_cb = 1000*(id+1) + il;
|
||||
auto split_u = u->splits[id];
|
||||
@@ -657,15 +673,21 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
auto split_d = d->splits[id];
|
||||
GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
|
||||
if (!split_u) continue;
|
||||
auto cur = input;
|
||||
auto cur = get_input_tensor_sm_graph(input, id);
|
||||
if (ffn_norm && ffn_norm->extra) {
|
||||
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
|
||||
GGML_ASSERT(norm->splits[id]);
|
||||
cur = llm_build_norm(ctx, input, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
|
||||
if (is_norm) {
|
||||
cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
|
||||
GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
|
||||
cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
|
||||
} else {
|
||||
cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
|
||||
}
|
||||
cb(cur, "ffn_inp_normed", il_cb);
|
||||
}
|
||||
else if (input->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx, input, GGML_TYPE_F32);
|
||||
else if (cur->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
}
|
||||
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
|
||||
cb(cur, "ffn_up_gate", il_cb);
|
||||
@@ -681,32 +703,31 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
if (graph) {
|
||||
ggml_build_forward_expand(graph, cur);
|
||||
}
|
||||
ffn.push_back(cur);
|
||||
ffn[id] = cur;
|
||||
id_last = id;
|
||||
}
|
||||
if (ffn.size() == 1) return ffn.front();
|
||||
auto cur = ggml_add(ctx, ffn[0], ffn[1]);
|
||||
cb(cur, "combine_ffn", il);
|
||||
cur->op_params[0] = 0xff;
|
||||
for (int id = 2; id < int(ffn.size()); ++id) {
|
||||
cur = ggml_add(ctx, cur, ffn[id]);
|
||||
cb(cur, "combine_ffn", il);
|
||||
GGML_ASSERT(id_last >= 0);
|
||||
if (add_input) {
|
||||
ffn[id_last] = ggml_add(ctx, ffn[id_last], input);
|
||||
cb(ffn[id_last], "ffn_with_inp", il);
|
||||
}
|
||||
if (ffn.size() > 2) {
|
||||
cur->op_params[0] = 0xff;
|
||||
if (add_extra) {
|
||||
ffn[id_last] = ggml_add(ctx, ffn[id_last], add_extra);
|
||||
cb(ffn[id_last], "ffn_with_inp", il);
|
||||
}
|
||||
//if (cur->type != GGML_TYPE_F32) {
|
||||
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
//}
|
||||
|
||||
auto cur = ggml_reduce(ctx, ffn.data(), u->n_device, GGML_OP_ADD);
|
||||
cb(cur, "ffn_combined", il);
|
||||
ggml_build_forward_expand(graph, cur);
|
||||
return cur;
|
||||
}
|
||||
|
||||
auto cur = input;
|
||||
if (ffn_norm) {
|
||||
input = llm_build_norm(ctx, input, lctx.model.hparams, ffn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cur = llm_build_norm(ctx, cur, lctx.model.hparams, ffn_norm, NULL, is_norm ? LLM_NORM : LLM_NORM_RMS, cb, il);
|
||||
cb(input, "ffn_norm", il);
|
||||
}
|
||||
else if (input->type != GGML_TYPE_F32) {
|
||||
input = ggml_cast(ctx, input, GGML_TYPE_F32);
|
||||
if (cur->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
}
|
||||
|
||||
if (lctx.cparams.fused_up_gate &&
|
||||
@@ -714,7 +735,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
|
||||
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
|
||||
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
|
||||
auto cur = ggml_fused_up_gate(ctx, up, gate, input, unary_op);
|
||||
cur = ggml_fused_up_gate(ctx, up, gate, cur, unary_op);
|
||||
cb(cur, "ffn_up_gate", il);
|
||||
if (down) {
|
||||
cur = llm_build_lora_mm(lctx, ctx, down, cur);
|
||||
@@ -733,10 +754,18 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
cur = ggml_mul(ctx, cur, down_s);
|
||||
cb(cur, "ffn_down_s", il);
|
||||
}
|
||||
if (add_input) {
|
||||
cur = ggml_add(ctx, cur, input);
|
||||
cb(cur, "ffn_out_with_inp", il);
|
||||
}
|
||||
if (add_extra) {
|
||||
cur = ggml_add(ctx, cur, add_extra);
|
||||
cb(cur, "ffn_out_with_inp", il);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, input) : input;
|
||||
struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
|
||||
cb(tmp, "ffn_up", il);
|
||||
|
||||
if (up_b) {
|
||||
@@ -749,7 +778,6 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
cb(tmp, "ffn_up_s", il);
|
||||
}
|
||||
|
||||
auto cur = input;
|
||||
if (gate) {
|
||||
switch (type_gate) {
|
||||
case LLM_FFN_SEQ:
|
||||
@@ -849,6 +877,15 @@ ggml_tensor * llm_build_context::llm_build_ffn(
|
||||
cb(cur, "ffn_down_s", il);
|
||||
}
|
||||
|
||||
if (add_input) {
|
||||
cur = ggml_add(ctx, cur, input);
|
||||
cb(cur, "ffn_out_with_inp", il);
|
||||
}
|
||||
if (add_extra) {
|
||||
cur = ggml_add(ctx, cur, add_extra);
|
||||
cb(cur, "ffn_out_with_inp", il);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@@ -868,7 +905,9 @@ ggml_tensor * llm_build_context::llm_build_moe_ffn(
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph) {
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) {
|
||||
|
||||
auto input = cur;
|
||||
|
||||
int64_t n_embd = cur->ne[0];
|
||||
int64_t n_tokens = cur->ne[1];
|
||||
@@ -1040,20 +1079,30 @@ llm_expert_gating_func_type gating_op,
|
||||
if (lctx.cparams.fused_mmad) {
|
||||
experts = ggml_mul_multi_add(ctx, experts, weights);
|
||||
cb(experts, "ffn_moe_weighted", il);
|
||||
if (add_input) {
|
||||
experts = ggml_add(ctx, experts, input);
|
||||
cb(experts, "ffn_out_with_inp", il);
|
||||
}
|
||||
return experts;
|
||||
}
|
||||
experts = ggml_mul(ctx, experts, weights);
|
||||
cb(experts, "ffn_moe_weighted", il);
|
||||
}
|
||||
|
||||
ggml_tensor * result;
|
||||
if (n_expert_used == 1) {
|
||||
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
|
||||
result = ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
|
||||
}
|
||||
if (n_expert_used == 2) {
|
||||
return ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
|
||||
result = ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
|
||||
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
|
||||
}
|
||||
return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
|
||||
result = ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
|
||||
if (add_input) {
|
||||
cb(result, "ffn_out", il);
|
||||
result = ggml_add(ctx, result, input);
|
||||
}
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
@@ -1076,7 +1125,7 @@ ggml_tensor * llm_build_context::llm_build_std_moe_ffn(ggml_context * ctx, llama
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
llm_ffn_op_type type_op_shexp,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph) {
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) {
|
||||
|
||||
auto split_up_exps = (ggml_split_tensor_t *)up_exps->extra;
|
||||
auto split_gate_exps = (ggml_split_tensor_t *)gate_exps->extra;
|
||||
@@ -1092,10 +1141,10 @@ llm_expert_gating_func_type gating_op,
|
||||
if (ffn_norm) {
|
||||
auto the_ffn_norm = ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[lctx.model.main_gpu] : ffn_norm;
|
||||
GGML_ASSERT(the_ffn_norm);
|
||||
cur = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cur = llm_build_norm(ctx, cur, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_inp_normed", il);
|
||||
}
|
||||
else if (cur->type != GGML_TYPE_F32) {
|
||||
if (cur->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
}
|
||||
auto the_gate_inp = gate_inp->extra ? ((ggml_split_tensor_t *)gate_inp->extra)->splits[lctx.model.main_gpu] : gate_inp;
|
||||
@@ -1110,8 +1159,12 @@ llm_expert_gating_func_type gating_op,
|
||||
the_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
type_op, norm_w, scale_w, w_scale,
|
||||
gating_op, cb, il, graph);
|
||||
gating_op, cb, il, graph, false);
|
||||
cb(routed_out, "routed_out", il);
|
||||
if (add_input) {
|
||||
routed_out = ggml_add(ctx, routed_out, input);
|
||||
cb(routed_out, "routed_out_with_inp", il);
|
||||
}
|
||||
ggml_build_forward_expand(graph, routed_out);
|
||||
|
||||
if (up_shexp && gate_shexp && down_shexp) {
|
||||
@@ -1176,26 +1229,27 @@ llm_expert_gating_func_type gating_op,
|
||||
}
|
||||
GGML_ASSERT(split_up_exps && split_gate_exps && split_down_exps);
|
||||
GGML_ASSERT(split_up_exps->n_device == split_gate_exps->n_device && split_up_exps->n_device == split_down_exps->n_device);
|
||||
std::vector<ggml_tensor *> results; results.reserve(split_up_exps->n_device);
|
||||
std::vector<ggml_tensor *> results(split_up_exps->n_device, nullptr);
|
||||
GGML_ASSERT((!split_up_shexp && !split_gate_shexp && !split_down_shexp) ||
|
||||
( split_up_shexp && split_gate_shexp && split_down_shexp));
|
||||
auto split_gate_inp = (ggml_split_tensor_t *)gate_inp->extra;
|
||||
GGML_ASSERT(split_gate_inp && split_gate_inp->n_device == split_up_exps->n_device);
|
||||
auto split_exp_probs_b = exp_probs_b ? (ggml_split_tensor_t *)exp_probs_b->extra : nullptr;
|
||||
GGML_ASSERT(!split_exp_probs_b || split_exp_probs_b->n_device == split_up_exps->n_device);
|
||||
int last_id = -1;
|
||||
for (int id = 0; id < split_up_exps->n_device; ++id) {
|
||||
GGML_ASSERT((split_up_exps->splits[id] && split_gate_exps->splits[id] && split_down_exps->splits[id]) ||
|
||||
(!split_up_exps->splits[id] && !split_gate_exps->splits[id] && !split_down_exps->splits[id]));
|
||||
if (!split_up_exps->splits[id]) continue;
|
||||
int il_cb = 1000*(id + 1) + il;
|
||||
auto cur = input;
|
||||
auto cur = get_input_tensor_sm_graph(input, id);
|
||||
if (ffn_norm) {
|
||||
auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
|
||||
GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
|
||||
cur = llm_build_norm(ctx, input, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
|
||||
cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_inp_normed", il_cb);
|
||||
}
|
||||
else if (cur->type != GGML_TYPE_F32) {
|
||||
if (cur->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
}
|
||||
auto routed_out = llm_build_moe_ffn(ctx, lctx, cur,
|
||||
@@ -1206,7 +1260,7 @@ llm_expert_gating_func_type gating_op,
|
||||
split_exp_probs_b ? split_exp_probs_b->splits[id] : nullptr,
|
||||
n_expert, n_expert_used,
|
||||
type_op, norm_w, scale_w, w_scale,
|
||||
gating_op, cb, il, graph);
|
||||
gating_op, cb, il, graph, false);
|
||||
cb(routed_out, "routed_out", il_cb);
|
||||
|
||||
if (split_up_shexp) {
|
||||
@@ -1229,19 +1283,20 @@ llm_expert_gating_func_type gating_op,
|
||||
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
|
||||
cb(cur, "ffn_out_f16", il_cb);
|
||||
}
|
||||
ggml_build_forward_expand(graph, routed_out);
|
||||
results.push_back(cur);
|
||||
ggml_build_forward_expand(graph, cur);
|
||||
results[id] = cur;
|
||||
last_id = id;
|
||||
}
|
||||
GGML_ASSERT(last_id >= 0);
|
||||
if (add_input) {
|
||||
results[last_id] = ggml_add(ctx, results[last_id], input);
|
||||
cb(results[last_id], "ffn_inp_added", il);
|
||||
}
|
||||
GGML_ASSERT(!results.empty());
|
||||
if (results.size() == 1) return results.front();
|
||||
|
||||
auto cur = ggml_add(ctx, results[0], results[1]);
|
||||
cur->op_params[0] = 0xff;
|
||||
cb(cur, "ffn_combined", il);
|
||||
for (int id = 2; id < int(results.size()); ++id) {
|
||||
cur = ggml_add(ctx, cur, results[id]);
|
||||
cb(cur, "ffn_combined", il);
|
||||
}
|
||||
auto cur = ggml_reduce(ctx, results.data(), split_up_exps->n_device, GGML_OP_ADD);
|
||||
cb(cur, "moe_ffn_combined", il);
|
||||
ggml_build_forward_expand(graph, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@@ -1754,7 +1809,7 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
// self-attention
|
||||
if (use_rope) {
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr,
|
||||
this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
|
||||
this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il, true, false, true);
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -1801,15 +1856,18 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
|
||||
this_n_swa);
|
||||
}
|
||||
//printf("%s: attn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op));
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
cb(cur, "last_attn", il);
|
||||
cb(inpSA, "last_ffn_inp", il);
|
||||
if (!use_rope) {
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
cb(inpSA, "last_ffn_inp", il);
|
||||
}
|
||||
}
|
||||
|
||||
// For Granite architecture
|
||||
@@ -1818,8 +1876,13 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
ggml_tensor * ffn_inp;
|
||||
if (use_rope) {
|
||||
ffn_inp = cur;
|
||||
} else {
|
||||
ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
@@ -1829,7 +1892,7 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf);
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else if (model.arch == LLM_ARCH_LLAMA4) {
|
||||
// llama4 MoE
|
||||
@@ -1846,7 +1909,7 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SIGMOID,
|
||||
cb, il, gf);
|
||||
cb, il, gf, true);
|
||||
|
||||
// Shared experts
|
||||
ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp_normed,
|
||||
@@ -1875,9 +1938,10 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il, gf);
|
||||
cb, il, gf, true);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
//printf("%s: ffn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op));
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_residual_scale) {
|
||||
@@ -1885,8 +1949,8 @@ ggml_cgraph * llm_build_context::build_llama() {
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
//cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
//cb(cur, "ffn_out", il);
|
||||
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
@@ -3933,23 +3997,26 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
//struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
//cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
//cb(cur, "attn_norm", il);
|
||||
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0,
|
||||
il, true, false, true);
|
||||
//printf("%s: attn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op));
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
//inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
auto ffn_inp = cur;
|
||||
//struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
//cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_gate_inp, nullptr,
|
||||
@@ -3963,9 +4030,11 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true, false, 0.0f,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
LLM_FFN_SILU, cb, il, gf);
|
||||
LLM_FFN_SILU, cb, il, gf, true);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
//printf("%s: ffn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op));
|
||||
|
||||
//cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
@@ -6818,7 +6887,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
|
||||
// self-attention
|
||||
if (rope_cache == nullptr) {
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il);
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true);
|
||||
} else {
|
||||
// Pre-attention norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
@@ -6862,8 +6931,13 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
}
|
||||
|
||||
// residual connection for attention output
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
ggml_tensor * ffn_inp;
|
||||
if (rope_cache) {
|
||||
ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
} else {
|
||||
ffn_inp = cur;
|
||||
}
|
||||
|
||||
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
||||
// dense FFN
|
||||
@@ -6872,7 +6946,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf);
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
@@ -6887,39 +6961,11 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale,
|
||||
(llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
LLM_FFN_SILU, cb, il, gf);
|
||||
|
||||
//// Post-attention norm
|
||||
//cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
//cb(cur, "post_attn_norm", il);
|
||||
//// MoE FFN
|
||||
//auto routed_out = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
// model.layers[il].ffn_gate_inp,
|
||||
// model.layers[il].ffn_up_exps,
|
||||
// model.layers[il].ffn_gate_exps,
|
||||
// model.layers[il].ffn_down_exps,
|
||||
// model.layers[il].ffn_exp_probs_b,
|
||||
// n_expert, n_expert_used,
|
||||
// LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
// true, hparams.expert_weights_scale,
|
||||
// (enum llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
// cb, il, gf);
|
||||
//cb(routed_out, "routed_out", il);
|
||||
|
||||
//auto shared_out = llm_build_ffn(ctx0, lctx, nullptr, cur,
|
||||
// model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
// model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
// model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
// NULL,
|
||||
// LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
//cb(shared_out, "ffn_shexp_out", il);
|
||||
|
||||
//cur = ggml_add(ctx0, routed_out, shared_out);
|
||||
//cb(cur, "ffn_out", il);
|
||||
LLM_FFN_SILU, cb, il, gf, true);
|
||||
}
|
||||
|
||||
// residual and context vector
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
//cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
@@ -7229,48 +7275,25 @@ ggml_cgraph * llm_build_context::build_cohere2() {
|
||||
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
|
||||
|
||||
// norm
|
||||
auto attn_norm = model.layers[il].attn_norm;
|
||||
int id = -1;
|
||||
if (attn_norm->extra) {
|
||||
auto extra = (ggml_split_tensor_t *)attn_norm->extra;
|
||||
for (int i = extra->n_device-1; i >= 0; --i) {
|
||||
if (extra->splits[i]) {
|
||||
attn_norm = extra->splits[i];
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, attn_norm, NULL, LLM_NORM, cb, il);
|
||||
if (id >= 0) {
|
||||
ggml_backend_sched_set_tensor_backend(lctx.sched, cur->src[0], ggml_backend_sched_get_backend(lctx.sched, id));
|
||||
}
|
||||
cb(cur, "attn_norm", il);
|
||||
auto ffn_inp = cur;
|
||||
|
||||
// self-attention
|
||||
auto attn_out = build_std_attention(gf, nullptr, cur, inp_pos, nullptr, KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), 0.f,
|
||||
is_sliding ? hparams.n_swa : 0, il, is_sliding, true);
|
||||
auto attn_out = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), 0.f,
|
||||
is_sliding ? hparams.n_swa : 0, il, is_sliding, false, true, true);
|
||||
cb(attn_out, "attn_out", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
attn_out = ggml_get_rows(ctx0, attn_out, inp_out_ids);
|
||||
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
cur = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
|
||||
cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
|
||||
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
|
||||
cb, il, gf);
|
||||
cb, il, gf, false, true, attn_out);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// add together residual + FFN + self-attention
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
@@ -7279,9 +7302,6 @@ ggml_cgraph * llm_build_context::build_cohere2() {
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
//if (cur->type != GGML_TYPE_F32) {
|
||||
// cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
|
||||
//}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
@@ -9312,10 +9332,11 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * the_attn_norm,
|
||||
ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale,
|
||||
int n_swa, int il, bool do_rope, bool add_graph_split) {
|
||||
int n_swa, int il, bool do_rope, bool add_graph_split, bool add_input, bool is_norm) {
|
||||
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
|
||||
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
|
||||
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
|
||||
//printf("%s: %s\n", __func__, ggml_op_name(input->op));
|
||||
ggml_split_tensor_t * attn_norm = the_attn_norm ? (ggml_split_tensor_t *)the_attn_norm->extra : nullptr;
|
||||
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
|
||||
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
|
||||
@@ -9342,7 +9363,8 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
bv = (ggml_split_tensor_t *)model.layers[il].bv->extra;
|
||||
GGML_ASSERT(bv->n_device == wq->n_device);
|
||||
}
|
||||
std::vector<ggml_tensor*> attn; attn.reserve(wq->n_device);
|
||||
std::vector<ggml_tensor*> attn(wq->n_device, nullptr);
|
||||
int id_last = -1;
|
||||
for (int id = 0; id < wq->n_device; ++id) {
|
||||
int il_cb = 1000*(id+1) + il;
|
||||
auto split_wq = wq->splits[id];
|
||||
@@ -9354,13 +9376,22 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
|
||||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
|
||||
if (!split_wq) continue;
|
||||
auto cur = input;
|
||||
auto cur = get_input_tensor_sm_graph(input, id);
|
||||
if (attn_norm) {
|
||||
auto split_norm = attn_norm->splits[id];
|
||||
cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il_cb);
|
||||
if (is_norm) {
|
||||
cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
|
||||
GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
|
||||
cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
|
||||
} else {
|
||||
cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
|
||||
}
|
||||
}
|
||||
else if (cur->type != GGML_TYPE_F32) {
|
||||
//if (attn_norm) {
|
||||
// auto split_norm = attn_norm->splits[id];
|
||||
// cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, is_norm ? LLM_NORM : LLM_NORM_RMS, cb, il);
|
||||
// cb(cur, "attn_norm", il_cb);
|
||||
//}
|
||||
if (cur->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
|
||||
}
|
||||
auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
|
||||
@@ -9486,42 +9517,24 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
|
||||
}
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
attn.push_back(cur);
|
||||
attn[id] = cur;
|
||||
id_last = id;
|
||||
}
|
||||
GGML_ASSERT(!attn.empty());
|
||||
if (attn.size() == 1) return attn.front();
|
||||
//if (attn.size() > 2 && attn.size()%2 == 0) {
|
||||
// for (int id = 0; id < int(attn.size()/2); ++id) {
|
||||
// attn[id] = ggml_add(ctx0, attn[2*id+0], attn[2*id+1]);
|
||||
// attn[id]->op_params[0] = 0xff;
|
||||
// }
|
||||
// attn.resize(attn.size()/2);
|
||||
// auto cur = ggml_add(ctx0, attn[0], attn[1]);
|
||||
// cur->op_params[0] = 0xff;
|
||||
// cur->op_params[0] = 0xff;
|
||||
// for (int id = 2; id < (int)attn.size(); ++id) {
|
||||
// cur = ggml_add(ctx0, cur, attn[id]);
|
||||
// cb(cur, "combine_attn", il);
|
||||
// }
|
||||
// return cur;
|
||||
//}
|
||||
auto cur = ggml_add(ctx0, attn[0], attn[1]);
|
||||
cb(cur, "combine_attn", il);
|
||||
cur->op_params[0] = 0xff;
|
||||
for (int id = 2; id < (int)attn.size(); ++id) {
|
||||
cur = ggml_add(ctx0, cur, attn[id]);
|
||||
cb(cur, "combine_attn", il);
|
||||
}
|
||||
if (attn.size() > 2) {
|
||||
cur->op_params[0] = 0xff;
|
||||
GGML_ASSERT(id_last >= 0);
|
||||
if (add_input) {
|
||||
attn[id_last] = ggml_add(ctx0, attn[id_last], input);
|
||||
cb(attn[id_last], "attn_out_with_input", il);
|
||||
}
|
||||
auto cur = ggml_reduce(ctx0, attn.data(), wq->n_device, GGML_OP_ADD);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
cb(cur, "attn_combined", il);
|
||||
return cur;
|
||||
}
|
||||
}
|
||||
|
||||
auto cur = input;
|
||||
if (the_attn_norm) {
|
||||
cur = llm_build_norm(ctx0, cur, hparams, the_attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cur = llm_build_norm(ctx0, cur, hparams, the_attn_norm, NULL, is_norm ? LLM_NORM : LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
}
|
||||
|
||||
@@ -9549,5 +9562,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
|
||||
|
||||
if (add_input) {
|
||||
cb(cur, "attn_out", il);
|
||||
cur = ggml_add(ctx0, cur, input);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@@ -335,7 +335,8 @@ struct llm_build_context {
|
||||
ggml_tensor * act_scales,
|
||||
llm_ffn_op_type type_op,
|
||||
llm_ffn_gate_type type_gate,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr);
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false,
|
||||
bool is_norm = false, ggml_tensor * add_extra = nullptr);
|
||||
|
||||
static ggml_tensor * llm_build_moe_ffn(ggml_context * ctx, llama_context & lctx,
|
||||
ggml_tensor * cur,
|
||||
@@ -351,7 +352,7 @@ struct llm_build_context {
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr);
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false);
|
||||
|
||||
static ggml_tensor * llm_build_moe_ffn(ggml_context * ctx, llama_context & lctx,
|
||||
ggml_tensor * cur,
|
||||
@@ -367,7 +368,7 @@ llm_expert_gating_func_type gating_op,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr) {
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false) {
|
||||
return llm_build_moe_ffn(ctx, lctx, cur,
|
||||
gate_inp, nullptr,
|
||||
up_exps, nullptr,
|
||||
@@ -376,7 +377,7 @@ llm_expert_gating_func_type gating_op,
|
||||
exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
type_op, norm_w, scale_w, w_scale,
|
||||
gating_op, cb, il, graph);
|
||||
gating_op, cb, il, graph, add_input);
|
||||
}
|
||||
|
||||
static ggml_tensor * llm_build_std_moe_ffn(ggml_context * ctx, llama_context & lctx,
|
||||
@@ -398,7 +399,7 @@ llm_expert_gating_func_type gating_op,
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
llm_ffn_op_type type_op_shexp,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph);
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input = false);
|
||||
|
||||
static ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids);
|
||||
|
||||
@@ -410,6 +411,6 @@ llm_expert_gating_func_type gating_op,
|
||||
|
||||
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale,
|
||||
int n_swa, int il, bool do_rope = true, bool add_graph_split = false);
|
||||
int n_swa, int il, bool do_rope = true, bool add_graph_split = false, bool add_input = false, bool is_norm = false);
|
||||
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user