mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-27 09:53:40 +00:00
WIP - factor out split attention
This commit is contained in:
@@ -9307,192 +9307,209 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
return result;
|
||||
}
|
||||
|
||||
// Build the split (tensor-parallel) standard attention sub-graph for layer `il`.
//
// NOTE(review): this span was corrupted by diff-viewer gutter artifacts ('|' and
// '||||' lines interleaved with every code line); they have been removed here.
// The code lines themselves are reproduced unchanged.
//
// One attention sub-graph is built per device. On entry, `input` holds one
// activation tensor per device (nullptr for devices that own no split of this
// layer); each owned entry is overwritten in place with that device's attention
// output. Flash attention must be enabled and Q/K/V/O must all be split tensors
// (tensor->extra carries a ggml_split_tensor_t with the per-device splits).
//
// Parameters:
//   gf              - graph to expand with the attention nodes
//   input           - per-device inputs, replaced in place by per-device outputs
//   inp_pos         - token positions for RoPE
//   rope_factors_in - explicit RoPE frequency factors; if null, per-device
//                     split rope_freqs (when present) are used instead
//   KQ_mask         - attention mask passed to flash attention
//   sinks           - optional attention sinks added to the FA node
//   inp_attn_scale  - optional per-position Q scale (temperature scaling)
//   KQ_scale        - 1/sqrt(head_dim) style attention scale
//   f_attn_scale    - scale applied inside the fused QKV matmul helper
//   n_swa           - sliding-window size; > 0 enables SWA on the FA node
//   il              - layer index
void llm_build_context::build_std_attention(ggml_cgraph * gf, std::vector<ggml_tensor *> & input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
        ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il) {
    // This overload only supports the flash-attention path with fully split,
    // non-fused attention weights.
    GGML_ASSERT(cparams.flash_attn);
    GGML_ASSERT(!model.layers[il].wqkv);
    GGML_ASSERT(!model.layers[il].wqk);
    GGML_ASSERT(model.layers[il].wq && model.layers[il].wq->extra);
    GGML_ASSERT(model.layers[il].wk && model.layers[il].wk->extra);
    GGML_ASSERT(model.layers[il].wv && model.layers[il].wv->extra);
    GGML_ASSERT(kv_self.k_l[il]->extra && kv_self.v_l[il]->extra);

    // Per-device split views of the attention weights and the KV cache.
    auto attn_norm = model.layers[il].attn_norm ? (ggml_split_tensor_t *)model.layers[il].attn_norm->extra : nullptr;
    auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
    auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
    auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
    auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
    GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
    GGML_ASSERT(wq->n_device == int(input.size()));
    auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
    auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
    GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);

    // Optional biases; each, when present, must be split across the same devices.
    ggml_split_tensor_t *bq = nullptr, *bo = nullptr, *bk = nullptr, *bv = nullptr;
    if (model.layers[il].bq && model.layers[il].bq->extra) {
        bq = (ggml_split_tensor_t *)model.layers[il].bq->extra;
        GGML_ASSERT(bq->n_device == wq->n_device);
    }
    if (model.layers[il].bo && model.layers[il].bo->extra) {
        bo = (ggml_split_tensor_t *)model.layers[il].bo->extra;
        GGML_ASSERT(bo->n_device == wq->n_device);
    }
    if (model.layers[il].bk && model.layers[il].bk->extra) {
        bk = (ggml_split_tensor_t *)model.layers[il].bk->extra;
        GGML_ASSERT(bk->n_device == wq->n_device);
    }
    if (model.layers[il].bv && model.layers[il].bv->extra) {
        bv = (ggml_split_tensor_t *)model.layers[il].bv->extra;
        GGML_ASSERT(bv->n_device == wq->n_device);
    }

    for (int id = 0; id < wq->n_device; ++id) {
        // Callback layer id encodes the device: 1000*(device+1) + layer.
        int il_cb = 1000*(id+1) + il;
        auto split_wq = wq->splits[id];
        auto split_wk = wk->splits[id];
        auto split_wv = wv->splits[id];
        auto split_wo = wo->splits[id];
        auto split_kl = kl->splits[id];
        auto split_vl = vl->splits[id];
        // A device either owns all of this layer's splits or none of them.
        GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
                    (split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
        if (!split_wq) {
            GGML_ASSERT(input[id] == nullptr);
            continue;
        }
        auto cur = input[id];
        if (attn_norm) {
            cur = llm_build_norm(ctx0, cur, hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
            cb(cur, "attn_norm", il_cb);
        }
        else if (cur->type != GGML_TYPE_F32) {
            // No norm to absorb the cast, so cast explicitly (inputs may arrive as f16, see the
            // f16 cast at the bottom of this loop).
            cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
        }
        // Q/K norms may themselves be split; fall back to the shared tensor if not.
        auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
            ((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
        auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
            ((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] : model.layers[il].attn_k_norm : nullptr;
        auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
                split_wq, bq ? bq->splits[id] : nullptr,
                split_wk, bk ? bk->splits[id] : nullptr,
                split_wv, bv ? bv->splits[id] : nullptr,
                the_q_norm, the_k_norm, f_attn_scale, il_cb);
        auto rope_factors = rope_factors_in;
        if (!rope_factors && model.layers[il].rope_freqs && model.layers[il].rope_freqs->extra) {
            auto extra = (ggml_split_tensor_t *)model.layers[il].rope_freqs->extra;
            rope_factors = extra->splits[id];
        }
        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow);
        cb(Qcur, "Qcur", il_cb);
        cb(Kcur, "Kcur", il_cb);
        if (inp_attn_scale) {
            Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
            cb(Qcur, "Qcur_temp_scaled", il_cb);
        }
        if (cparams.k_cache_hadamard) {
            // Hadamard transform applied to both Q and K is attention-invariant but
            // flattens the K distribution for cache quantization.
            Qcur = ggml_hadamard(ctx0, Qcur, hparams.n_embd_head_k);
            Kcur = ggml_hadamard(ctx0, Kcur, hparams.n_embd_head_k);
            cb(Qcur, "Qcur_hadamard", il_cb);
            cb(Kcur, "Kcur_hadamard", il_cb);
        }
        ggml_build_forward_expand(gf, Qcur);
        ggml_build_forward_expand(gf, Kcur);
        ggml_build_forward_expand(gf, Vcur);

        const int64_t n_embd_head_k = hparams.n_embd_head_k;
        // KV heads owned by this device, derived from this device's K split width.
        const int64_t n_head_kv = split_wk->ne[1] / n_embd_head_k;

        GGML_ASSERT(kv_self.size == cparams.n_ctx);

        // Two cache-copy slots (K and V) per layer per device.
        auto idx = 2*wq->n_device*il + 2*id;
        GGML_ASSERT(idx+1 < (int)lctx.cache_copies.size());
        auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
        ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
                k_row_size, k_row_size*n_head_kv*kv_head);

        lctx.cache_copies[idx+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
        lctx.cache_copies[idx+0].step = k_row_size*n_head_kv;

        // note: storing RoPE-ed version of K in the KV cache
        ggml_build_forward_expand(gf, lctx.cache_copies[idx+0].cpy);

        // Flash attention keeps V non-transposed in the cache.
        auto v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
                kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
        cb(v_cache_view, "v_cache_view", il_cb);
        lctx.cache_copies[idx+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
        lctx.cache_copies[idx+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
        ggml_build_forward_expand(gf, lctx.cache_copies[idx+1].cpy);

        auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
        cb(q, "q", il_cb);

        auto k = ggml_view_3d(ctx0, split_kl, n_embd_head_k, n_kv, n_head_kv,
                ggml_row_size(split_kl->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
                ggml_row_size(split_kl->type, n_embd_head_k), 0);
        cb(k, "k", il_cb);

        auto v = ggml_view_3d(ctx0, split_vl, n_embd_head_v, n_kv, n_head_kv,
                ggml_row_size(split_vl->type, split_wv->ne[1]),
                ggml_row_size(split_vl->type, n_embd_head_v), 0);
        cb(v, "v", il_cb);

#ifdef GGML_USE_VULKAN
        constexpr bool use_f32_precision = true;
#else
        constexpr bool use_f32_precision = false;
#endif
        cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
        cb(cur, "flash_attn", il_cb);
        ggml_flash_attn_ext_add_sinks(cur, sinks);
        if (n_swa > 0) {
            // op_params[4] carries the sliding-window size for the FA kernel.
            ((int32_t *)cur->op_params)[4] = n_swa;
        }

        // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
        if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
            (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
            model.arch == LLM_ARCH_GLM4_MOE) {
            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
        }

        cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
        cb(cur, "flash_attn_reshaped", il_cb);

        cur = llm_build_lora_mm(lctx, ctx0, split_wo, cur);
        if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
        cb(cur, "kqv_wo", il_cb);
        if (bo) {
            cur = ggml_add(ctx0, cur, bo->splits[id]);
            cb(cur, "kqv_wo_biased", il_cb);
        }
        if (cur->ne[1] >= 32) {
            // Larger batches: hand the result back in f16 to halve inter-device traffic.
            cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
        }
        ggml_build_forward_expand(gf, cur);
        input[id] = cur;
    }
}
|
||||
|
||||
ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il) {
|
||||
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
|
||||
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
|
||||
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
|
||||
ggml_split_tensor_t * attn_norm = model.layers[il].attn_norm ? (ggml_split_tensor_t *)model.layers[il].attn_norm->extra : nullptr;
|
||||
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
|
||||
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
|
||||
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
|
||||
auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
|
||||
GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
|
||||
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
|
||||
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
|
||||
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
|
||||
ggml_split_tensor_t *bq = nullptr, *bo = nullptr, *bk = nullptr, *bv = nullptr;
|
||||
if (model.layers[il].bq && model.layers[il].bq->extra) {
|
||||
bq = (ggml_split_tensor_t *)model.layers[il].bq->extra;
|
||||
GGML_ASSERT(bq->n_device == wq->n_device);
|
||||
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra &&
|
||||
kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
|
||||
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
|
||||
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
|
||||
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
|
||||
auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
|
||||
GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
|
||||
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
|
||||
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
|
||||
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
|
||||
std::vector<ggml_tensor *> attn(wq->n_device, nullptr);
|
||||
std::vector<int> ids; ids.reserve(wq->n_device);
|
||||
for (int id = 0; id < wq->n_device; ++id) {
|
||||
if (wq->splits[id]) {
|
||||
attn[id] = input;
|
||||
ids.push_back(id);
|
||||
}
|
||||
if (model.layers[il].bo && model.layers[il].bo->extra) {
|
||||
bo = (ggml_split_tensor_t *)model.layers[il].bo->extra;
|
||||
GGML_ASSERT(bo->n_device == wq->n_device);
|
||||
}
|
||||
if (model.layers[il].bk && model.layers[il].bk->extra) {
|
||||
bk = (ggml_split_tensor_t *)model.layers[il].bk->extra;
|
||||
GGML_ASSERT(bk->n_device == wq->n_device);
|
||||
}
|
||||
if (model.layers[il].bv && model.layers[il].bv->extra) {
|
||||
bv = (ggml_split_tensor_t *)model.layers[il].bv->extra;
|
||||
GGML_ASSERT(bv->n_device == wq->n_device);
|
||||
}
|
||||
std::vector<ggml_tensor*> attn; attn.reserve(wq->n_device);
|
||||
for (int id = 0; id < wq->n_device; ++id) {
|
||||
int il_cb = 1000*(id+1) + il;
|
||||
auto split_wq = wq->splits[id];
|
||||
auto split_wk = wk->splits[id];
|
||||
auto split_wv = wv->splits[id];
|
||||
auto split_wo = wo->splits[id];
|
||||
auto split_kl = kl->splits[id];
|
||||
auto split_vl = vl->splits[id];
|
||||
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
|
||||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
|
||||
if (!split_wq) continue;
|
||||
auto cur = input;
|
||||
if (attn_norm) {
|
||||
auto split_norm = attn_norm->splits[id];
|
||||
cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il_cb);
|
||||
}
|
||||
else if (cur->type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
|
||||
}
|
||||
auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
|
||||
((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
|
||||
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
|
||||
((ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra)->splits[id] : model.layers[il].attn_k_norm : nullptr;
|
||||
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
|
||||
split_wq, bq ? bq->splits[id] : nullptr,
|
||||
split_wk, bk ? bk->splits[id] : nullptr,
|
||||
split_wv, bv ? bv->splits[id] : nullptr,
|
||||
the_q_norm, the_k_norm, f_attn_scale, il_cb);
|
||||
auto rope_factors = rope_factors_in;
|
||||
if (!rope_factors && model.layers[il].rope_freqs && model.layers[il].rope_freqs->extra) {
|
||||
auto extra = (ggml_split_tensor_t *)model.layers[il].rope_freqs->extra;
|
||||
rope_factors = extra->splits[id];
|
||||
}
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il_cb);
|
||||
cb(Kcur, "Kcur", il_cb);
|
||||
if (inp_attn_scale) {
|
||||
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
|
||||
cb(Qcur, "Qcur_temp_scaled", il_cb);
|
||||
}
|
||||
if (cparams.k_cache_hadamard) {
|
||||
Qcur = ggml_hadamard(ctx0, Qcur, hparams.n_embd_head_k);
|
||||
Kcur = ggml_hadamard(ctx0, Kcur, hparams.n_embd_head_k);
|
||||
cb(Qcur, "Qcur_hadamard", il_cb);
|
||||
cb(Kcur, "Kcur_hadamard", il_cb);
|
||||
}
|
||||
ggml_build_forward_expand(gf, Qcur);
|
||||
ggml_build_forward_expand(gf, Kcur);
|
||||
ggml_build_forward_expand(gf, Vcur);
|
||||
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
const int64_t n_head_kv = split_wk->ne[1] / n_embd_head_k;
|
||||
|
||||
GGML_ASSERT(kv_self.size == cparams.n_ctx);
|
||||
|
||||
auto idx = 2*wq->n_device*il + 2*id;
|
||||
GGML_ASSERT(idx+1 < (int)lctx.cache_copies.size());
|
||||
auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
|
||||
ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
|
||||
k_row_size, k_row_size*n_head_kv*kv_head);
|
||||
|
||||
lctx.cache_copies[idx+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
|
||||
lctx.cache_copies[idx+0].step = k_row_size*n_head_kv;
|
||||
|
||||
// note: storing RoPE-ed version of K in the KV cache
|
||||
ggml_build_forward_expand(gf, lctx.cache_copies[idx+0].cpy);
|
||||
|
||||
struct ggml_tensor * v_cache_view = nullptr;
|
||||
|
||||
if (cparams.flash_attn) {
|
||||
v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
|
||||
kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
|
||||
lctx.cache_copies[idx+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
|
||||
} else {
|
||||
// note: the V cache is transposed when not using flash attention
|
||||
v_cache_view = ggml_view_2d(ctx0, split_vl, n_tokens, split_wv->ne[1],
|
||||
( n_ctx)*ggml_element_size(split_vl),
|
||||
(kv_head)*ggml_element_size(split_vl));
|
||||
lctx.cache_copies[idx+1].step = ggml_element_size(split_vl);
|
||||
|
||||
Vcur = ggml_transpose(ctx0, Vcur);
|
||||
}
|
||||
cb(v_cache_view, "v_cache_view", il_cb);
|
||||
|
||||
lctx.cache_copies[idx+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
|
||||
ggml_build_forward_expand(gf, lctx.cache_copies[idx+1].cpy);
|
||||
|
||||
auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(q, "q", il_cb);
|
||||
|
||||
auto k = ggml_view_3d(ctx0, split_kl, n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(split_kl->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
|
||||
ggml_row_size(split_kl->type, n_embd_head_k), 0);
|
||||
cb(k, "k", il_cb);
|
||||
|
||||
auto v = ggml_view_3d(ctx0, split_vl, n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(split_vl->type, split_wv->ne[1]),
|
||||
ggml_row_size(split_vl->type, n_embd_head_v), 0);
|
||||
cb(v, "v", il_cb);
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
constexpr bool use_f32_precision = true;
|
||||
#else
|
||||
constexpr bool use_f32_precision = false;
|
||||
#endif
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, "flash_attn", il_cb);
|
||||
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
||||
if (n_swa > 0) {
|
||||
((int32_t *)cur->op_params)[4] = n_swa;
|
||||
}
|
||||
|
||||
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
|
||||
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
|
||||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
|
||||
model.arch == LLM_ARCH_GLM4_MOE) {
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
|
||||
cb(cur, "flash_attn_reshaped", il_cb);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, split_wo, cur);
|
||||
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
|
||||
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
cb(cur, "kqv_wo", il_cb);
|
||||
if (bo) {
|
||||
cur = ggml_add(ctx0, cur, bo->splits[id]);
|
||||
cb(cur, "kqv_wo_biased", il_cb);
|
||||
}
|
||||
if (cur->ne[1] >= 32) {
|
||||
cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
|
||||
}
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
attn.push_back(cur);
|
||||
}
|
||||
GGML_ASSERT(!attn.empty());
|
||||
if (attn.size() == 1) return attn.front();
|
||||
auto cur = ggml_add(ctx0, attn[0], attn[1]);
|
||||
cb(cur, "combine_attn", il);
|
||||
cur->op_params[0] = 0xff;
|
||||
for (int id = 2; id < (int)attn.size(); ++id) {
|
||||
cur = ggml_add(ctx0, cur, attn[id]);
|
||||
cb(cur, "combine_attn", il);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
GGML_ASSERT(!ids.empty());
|
||||
build_std_attention(gf, attn, inp_pos, rope_factors_in, KQ_mask, sinks, inp_attn_scale, KQ_scale, f_attn_scale, n_swa, il);
|
||||
|
||||
if (ids.size() == 1) return attn[ids.front()];
|
||||
auto cur = ggml_add(ctx0, attn[ids[0]], attn[ids[1]]);
|
||||
cb(cur, "combine_attn", il);
|
||||
cur->op_params[0] = 0xff;
|
||||
for (int id = 2; id < (int)ids.size(); ++id) {
|
||||
cur = ggml_add(ctx0, cur, attn[ids[id]]);
|
||||
cb(cur, "combine_attn", il);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
auto cur = input;
|
||||
|
||||
@@ -410,4 +410,7 @@ llm_expert_gating_func_type gating_op,
|
||||
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il);
|
||||
|
||||
void build_std_attention(ggml_cgraph * gf, std::vector<ggml_tensor *> & cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il);
|
||||
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user