mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
Move the Qwen-3.5 models to the standard attention mechanism (#1329)
This commit is contained in:
@@ -1827,9 +1827,9 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_con
|
|||||||
cb(Kcur, "Kcur_normed", il);
|
cb(Kcur, "Kcur_normed", il);
|
||||||
ggml_build_forward_expand(gf, Kcur);
|
ggml_build_forward_expand(gf, Kcur);
|
||||||
}
|
}
|
||||||
gate = ggml_sigmoid(ctx0, gate);
|
//gate = ggml_sigmoid(ctx0, gate);
|
||||||
//gate = ggml_reshape_2d(ctx0, gate, gate->ne[0]*gate->ne[1], gate->ne[2]);
|
//gate = ggml_reshape_2d(ctx0, gate, gate->ne[0]*gate->ne[1], gate->ne[2]);
|
||||||
cb(gate, "gate", il);
|
//cb(gate, "gate", il);
|
||||||
return {Qcur, Kcur, Vcur, gate};
|
return {Qcur, Kcur, Vcur, gate};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4536,62 +4536,6 @@ ggml_cgraph * llm_build_context::build_qwen35moe() {
|
|||||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
|
||||||
int sections[4];
|
|
||||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
|
||||||
|
|
||||||
auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {
|
|
||||||
|
|
||||||
auto Qaux = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
||||||
auto Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
||||||
auto Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
||||||
cb(Qaux, "Qaux", il);
|
|
||||||
cb(Kcur, "Kcur", il);
|
|
||||||
cb(Vcur, "Vcur", il);
|
|
||||||
ggml_build_forward_expand(gf, Qaux);
|
|
||||||
ggml_build_forward_expand(gf, Kcur);
|
|
||||||
ggml_build_forward_expand(gf, Vcur);
|
|
||||||
|
|
||||||
Qaux = ggml_reshape_3d(ctx0, Qaux, n_embd_head * 2, n_head, n_tokens);
|
|
||||||
auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], 0));
|
|
||||||
auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], n_embd_head*ggml_element_size(Qaux)), n_embd_head*n_head, n_tokens);
|
|
||||||
cb(Qcur, "Qcur", il);
|
|
||||||
cb(gate, "gate", il);
|
|
||||||
|
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
||||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
|
||||||
|
|
||||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
|
|
||||||
cb(Qcur, "Qcur_normed", il);
|
|
||||||
|
|
||||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
|
|
||||||
cb(Kcur, "Kcur_normed", il);
|
|
||||||
|
|
||||||
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
|
|
||||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
|
||||||
|
|
||||||
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
|
|
||||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
|
||||||
|
|
||||||
cb(Qcur, "Qcur_roped", il);
|
|
||||||
cb(Kcur, "Kcur_roped", il);
|
|
||||||
|
|
||||||
ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
|
|
||||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
|
|
||||||
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, cb, il);
|
|
||||||
cb(attn, "attn_pregate", il);
|
|
||||||
|
|
||||||
gate = ggml_sigmoid(ctx0, gate);
|
|
||||||
cb(gate, "gate_sigmoid", il);
|
|
||||||
attn = ggml_mul(ctx0, attn, gate);
|
|
||||||
cb(attn, "attn_gated", il);
|
|
||||||
|
|
||||||
attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
|
|
||||||
cb(attn, "attn_output", il);
|
|
||||||
|
|
||||||
return attn;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
ggml_tensor * inp_pos = build_inp_pos();
|
ggml_tensor * inp_pos = build_inp_pos();
|
||||||
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
||||||
@@ -4601,6 +4545,8 @@ ggml_cgraph * llm_build_context::build_qwen35moe() {
|
|||||||
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
|
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
|
||||||
ggml_set_input(lctx.inp_s_seq_qnext);
|
ggml_set_input(lctx.inp_s_seq_qnext);
|
||||||
|
|
||||||
|
float KQ_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||||
|
|
||||||
ggml_tensor * causal_mask = nullptr;
|
ggml_tensor * causal_mask = nullptr;
|
||||||
ggml_tensor * identity = nullptr;
|
ggml_tensor * identity = nullptr;
|
||||||
ggml_tensor * diag_mask = nullptr;
|
ggml_tensor * diag_mask = nullptr;
|
||||||
@@ -4616,25 +4562,26 @@ ggml_cgraph * llm_build_context::build_qwen35moe() {
|
|||||||
ggml_tensor * cur = nullptr;
|
ggml_tensor * cur = nullptr;
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
ggml_tensor * inpSA = inpL;
|
|
||||||
|
|
||||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
|
||||||
cb(cur, "attn_norm", il);
|
|
||||||
|
|
||||||
if (hparams.is_recurrent(il)) {
|
if (hparams.is_recurrent(il)) {
|
||||||
|
ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
|
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
|
||||||
|
if (il == n_layer - 1 && inp_out_ids) {
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(cur, "attn_residual", il);
|
||||||
} else {
|
} else {
|
||||||
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
|
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
|
||||||
|
KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1 && inp_out_ids) {
|
|
||||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
||||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
||||||
}
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0, cur, inpSA);
|
|
||||||
cb(cur, "attn_residual", il);
|
|
||||||
|
|
||||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
|
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
|
||||||
model.layers[il].ffn_gate_inp, nullptr,
|
model.layers[il].ffn_gate_inp, nullptr,
|
||||||
model.layers[il].ffn_up_exps, nullptr,
|
model.layers[il].ffn_up_exps, nullptr,
|
||||||
@@ -4673,62 +4620,6 @@ ggml_cgraph * llm_build_context::build_qwen35() {
|
|||||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
|
||||||
int sections[4];
|
|
||||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
|
||||||
|
|
||||||
auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * {
|
|
||||||
|
|
||||||
auto Qaux = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
|
||||||
auto Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
|
||||||
auto Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
|
||||||
cb(Qaux, "Qaux", il);
|
|
||||||
cb(Kcur, "Kcur", il);
|
|
||||||
cb(Vcur, "Vcur", il);
|
|
||||||
ggml_build_forward_expand(gf, Qaux);
|
|
||||||
ggml_build_forward_expand(gf, Kcur);
|
|
||||||
ggml_build_forward_expand(gf, Vcur);
|
|
||||||
|
|
||||||
Qaux = ggml_reshape_3d(ctx0, Qaux, n_embd_head * 2, n_head, n_tokens);
|
|
||||||
auto Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], 0));
|
|
||||||
auto gate = ggml_cont_2d(ctx0, ggml_view_3d(ctx0, Qaux, n_embd_head, n_head, n_tokens, Qaux->nb[1], Qaux->nb[2], n_embd_head*ggml_element_size(Qaux)), n_embd_head*n_head, n_tokens);
|
|
||||||
cb(Qcur, "Qcur", il);
|
|
||||||
cb(gate, "gate", il);
|
|
||||||
|
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
|
||||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
|
||||||
|
|
||||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
|
|
||||||
cb(Qcur, "Qcur_normed", il);
|
|
||||||
|
|
||||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
|
|
||||||
cb(Kcur, "Kcur_normed", il);
|
|
||||||
|
|
||||||
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
|
|
||||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
|
||||||
|
|
||||||
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
|
|
||||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
|
||||||
|
|
||||||
cb(Qcur, "Qcur_roped", il);
|
|
||||||
cb(Kcur, "Kcur_roped", il);
|
|
||||||
|
|
||||||
ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
|
|
||||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv,
|
|
||||||
hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, cb, il);
|
|
||||||
cb(attn, "attn_pregate", il);
|
|
||||||
|
|
||||||
gate = ggml_sigmoid(ctx0, gate);
|
|
||||||
cb(gate, "gate_sigmoid", il);
|
|
||||||
attn = ggml_mul(ctx0, attn, gate);
|
|
||||||
cb(attn, "attn_gated", il);
|
|
||||||
|
|
||||||
attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn);
|
|
||||||
cb(attn, "attn_output", il);
|
|
||||||
|
|
||||||
return attn;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
ggml_tensor * inp_pos = build_inp_pos();
|
ggml_tensor * inp_pos = build_inp_pos();
|
||||||
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
ggml_tensor * inp_out_ids = n_tokens > 1 ? build_inp_out_ids() : nullptr;
|
||||||
@@ -4738,6 +4629,8 @@ ggml_cgraph * llm_build_context::build_qwen35() {
|
|||||||
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
|
cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1);
|
||||||
ggml_set_input(lctx.inp_s_seq_qnext);
|
ggml_set_input(lctx.inp_s_seq_qnext);
|
||||||
|
|
||||||
|
float KQ_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||||
|
|
||||||
ggml_tensor * causal_mask = nullptr;
|
ggml_tensor * causal_mask = nullptr;
|
||||||
ggml_tensor * identity = nullptr;
|
ggml_tensor * identity = nullptr;
|
||||||
ggml_tensor * diag_mask = nullptr;
|
ggml_tensor * diag_mask = nullptr;
|
||||||
@@ -4753,25 +4646,23 @@ ggml_cgraph * llm_build_context::build_qwen35() {
|
|||||||
ggml_tensor * cur = nullptr;
|
ggml_tensor * cur = nullptr;
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
ggml_tensor * inpSA = inpL;
|
|
||||||
|
|
||||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
|
||||||
cb(cur, "attn_norm", il);
|
|
||||||
|
|
||||||
if (hparams.is_recurrent(il)) {
|
if (hparams.is_recurrent(il)) {
|
||||||
|
ggml_tensor * inpSA = inpL;
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
|
cur = delta.build_layer_attn_linear(ctx0, gf, cur, causal_mask, identity, diag_mask, il, cb);
|
||||||
|
if (il == n_layer - 1 && inp_out_ids) {
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||||
|
}
|
||||||
|
cur = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(cur, "attn_residual", il);
|
||||||
} else {
|
} else {
|
||||||
cur = build_layer_attn(cur, inp_pos, KQ_mask, il);
|
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
|
||||||
|
KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1 && inp_out_ids) {
|
|
||||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
||||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
|
||||||
}
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0, cur, inpSA);
|
|
||||||
cb(cur, "attn_residual", il);
|
|
||||||
|
|
||||||
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
|
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
|
||||||
model.layers[il].ffn_up, NULL, NULL,
|
model.layers[il].ffn_up, NULL, NULL,
|
||||||
model.layers[il].ffn_gate, NULL, NULL,
|
model.layers[il].ffn_gate, NULL, NULL,
|
||||||
@@ -10254,7 +10145,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
|||||||
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
|
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
|
||||||
((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] : model.layers[il].attn_k_norm : nullptr;
|
((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] : model.layers[il].attn_k_norm : nullptr;
|
||||||
ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
|
ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
|
||||||
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||||
auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, split_wq, split_wk, split_wv,
|
auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, split_wq, split_wk, split_wv,
|
||||||
the_q_norm, the_k_norm, il);
|
the_q_norm, the_k_norm, il);
|
||||||
Qcur = Q; Kcur = K; Vcur = V; gate = G;
|
Qcur = Q; Kcur = K; Vcur = V; gate = G;
|
||||||
@@ -10393,7 +10284,13 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
|||||||
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
|
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
|
||||||
cb(cur, "flash_attn_reshaped", il_cb);
|
cb(cur, "flash_attn_reshaped", il_cb);
|
||||||
if (gate) {
|
if (gate) {
|
||||||
cur = ggml_mul(ctx0, cur, gate);
|
if (false && cur->ne[1] == 1) { // we need to add GGML_UNARY_OP_SIGMOID to the ops supported by ggml_fused_mul_unary
|
||||||
|
cur = ggml_fused_mul_unary(ctx0, cur, gate, GGML_UNARY_OP_SIGMOID);
|
||||||
|
} else {
|
||||||
|
gate = ggml_sigmoid(ctx0, gate);
|
||||||
|
cb(gate, "gate", il_cb);
|
||||||
|
cur = ggml_mul(ctx0, cur, gate);
|
||||||
|
}
|
||||||
cb(cur, "qkv_gated", il_cb);
|
cb(cur, "qkv_gated", il_cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -10445,7 +10342,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
|||||||
auto input_normed = cur;
|
auto input_normed = cur;
|
||||||
|
|
||||||
ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
|
ggml_tensor *Qcur, *Kcur, *Vcur, *gate = nullptr;
|
||||||
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
|
||||||
auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, model.layers[il].wq, model.layers[il].wk, model.layers[il].wv,
|
auto [Q, K, V, G] = llm_build_mul_mat_qkv_gated(gf, cur, model.layers[il].wq, model.layers[il].wk, model.layers[il].wv,
|
||||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, il);
|
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, il);
|
||||||
Qcur = Q; Kcur = K; Vcur = V; gate = G;
|
Qcur = Q; Kcur = K; Vcur = V; gate = G;
|
||||||
@@ -10506,7 +10403,13 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
|||||||
if (gate) {
|
if (gate) {
|
||||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
|
cur = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
|
||||||
cur = ggml_mul(ctx0, cur, gate);
|
if (false && cur->ne[1] == 1) { // we need to add GGML_UNARY_OP_SIGMOID to the ops supported by ggml_fused_mul_unary
|
||||||
|
cur = ggml_fused_mul_unary(ctx0, cur, gate, GGML_UNARY_OP_SIGMOID);
|
||||||
|
} else {
|
||||||
|
gate = ggml_sigmoid(ctx0, gate);
|
||||||
|
cb(gate, "gate", il);
|
||||||
|
cur = ggml_mul(ctx0, cur, gate);
|
||||||
|
}
|
||||||
cb(cur, "qkv_gated", il);
|
cb(cur, "qkv_gated", il);
|
||||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||||
if (model.layers[il].bo) {
|
if (model.layers[il].bo) {
|
||||||
|
|||||||
Reference in New Issue
Block a user