Refinements
@@ -1343,12 +1343,18 @@ static ggml_tensor * build_glm45_fa(ggml_context * ctx, ggml_tensor * q, ggml_te
     auto ne1 = 8*v->ne[0];
     auto ne2 = 4*v->ne[0];
 
-    auto q1 = ggml_view_3d(ctx, q, q->ne[0], 8, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 0);
-    auto q2 = ggml_view_3d(ctx, q, q->ne[0], 4, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 8*q->ne[0]*ggml_element_size(q));
-    q1 = ggml_reshape_3d(ctx, ggml_cont(ctx, q1), q->ne[0], 8*k->ne[2], q->ne[1]);
-    q2 = ggml_reshape_3d(ctx, ggml_cont(ctx, q2), q->ne[0], 4*k->ne[2], q->ne[1]);
-    q1 = ggml_permute(ctx, q1, 0, 2, 1, 3);
-    q2 = ggml_permute(ctx, q2, 0, 2, 1, 3);
+    ggml_tensor *q1, *q2;
+    if (q->ne[1] == 1 && k->ne[2] == 1) {
+        q1 = ggml_view_3d(ctx, q, q->ne[0], 1, 8, q->nb[1], q->nb[2], 0);
+        q2 = ggml_view_3d(ctx, q, q->ne[0], 1, 4, q->nb[1], q->nb[2], 8*q->ne[0]*ggml_element_size(q));
+    } else {
+        q1 = ggml_view_3d(ctx, q, q->ne[0], 8, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 0);
+        q2 = ggml_view_3d(ctx, q, q->ne[0], 4, k->ne[2]*q->ne[1], q->nb[2], q->nb[1]/k->ne[2], 8*q->ne[0]*ggml_element_size(q));
+        q1 = ggml_reshape_3d(ctx, ggml_cont(ctx, q1), q->ne[0], 8*k->ne[2], q->ne[1]);
+        q2 = ggml_reshape_3d(ctx, ggml_cont(ctx, q2), q->ne[0], 4*k->ne[2], q->ne[1]);
+        q1 = ggml_permute(ctx, q1, 0, 2, 1, 3);
+        q2 = ggml_permute(ctx, q2, 0, 2, 1, 3);
+    }
 
     auto fa1 = ggml_flash_attn_ext(ctx, q1, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
     if (should_use_f32_precision) {
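Note: the new branch gives build_glm45_fa a zero-copy fast path. When there is a single query token and a single KV head (q->ne[1] == 1 && k->ne[2] == 1), the 12 query heads sit contiguously in memory, so the 8+4 head split needs only two plain views; the general path still has to go through ggml_cont/ggml_reshape/ggml_permute to regroup heads per KV head. A minimal standalone sketch of the fast-path views (not from this repo; it assumes ggml's public API and a contiguous q of shape (head_dim, 1, 12)):

    #include "ggml.h"

    int main() {
        ggml_init_params ip = {
            /*.mem_size   =*/ 16u*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        ggml_context * ctx = ggml_init(ip);

        // Single-token decode: q is (head_dim, n_tokens = 1, n_head = 12), contiguous.
        const int64_t head_dim = 128; // illustrative value
        ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, 1, 12);

        // Mirrors the fast path above: heads 0..7 and heads 8..11 as plain views,
        // no ggml_cont/ggml_reshape/ggml_permute required.
        ggml_tensor * q1 = ggml_view_3d(ctx, q, head_dim, 1, 8, q->nb[1], q->nb[2], 0);
        ggml_tensor * q2 = ggml_view_3d(ctx, q, head_dim, 1, 4, q->nb[1], q->nb[2],
                                        8*head_dim*ggml_element_size(q));
        GGML_ASSERT(q1->ne[2] + q2->ne[2] == q->ne[2]); // 8 + 4 = 12 heads covered

        ggml_free(ctx);
        return 0;
    }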
@@ -1435,7 +1441,7 @@ static ggml_tensor * llm_build_kqv(
             0);
     cb(v, "v", il);
 
-    if (/*k->ne[1] >= 8192 &&*/ q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
+    if (q->ne[1] == 1 && k->ne[1] >= 8192 && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
         k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer)) {
         cur = build_glm45_fa(ctx, q, k, v, kq_mask, kq_scale, should_use_f32_precision);
     } else {
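Note: the gate in llm_build_kqv is now explicit about when the GLM-4.5 head-split path pays off, instead of carrying the 8192 threshold around in a comment. A hypothetical restatement of the condition as a predicate (the helper name is mine; the layout assumptions are q->ne[2] = n_head, k->ne[2] = n_head_kv, k->ne[1] = n_kv):

    // Hypothetical helper, same logic as the gate above.
    static bool can_use_glm45_fa(const ggml_tensor * q, const ggml_tensor * k,
                                 const ggml_tensor * sinks, int n_swa) {
        return q->ne[1] == 1                       // single-token (decode) batches only
            && k->ne[1] >= 8192                    // KV cache long enough to pay off
            && q->ne[2] / k->ne[2] == 12           // GQA ratio 12, split as 8 + 4 heads
            && !sinks && n_swa == 0                // no attention sinks, no SWA
            && k->view_src && k->view_src->buffer  // K is a view into a KV-cache buffer
            && !ggml_backend_buffer_is_host(k->view_src->buffer); // cache is on the GPU
    }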
@@ -9225,6 +9231,22 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
 
     float freq_base_l  = n_swa > 0 ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
     float freq_scale_l = n_swa > 0 ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
+#ifdef GGML_USE_VULKAN
+    constexpr bool use_f32_precision = true;
+#else
+    constexpr bool use_f32_precision = false;
+#endif
+
+    bool should_use_f32_precision = use_f32_precision
+        || model.arch == LLM_ARCH_PHI2
+        || model.arch == LLM_ARCH_PHI3
+        || model.arch == LLM_ARCH_GPTNEOX
+        || model.arch == LLM_ARCH_QWEN2
+        || model.arch == LLM_ARCH_COHERE2
+        || model.arch == LLM_ARCH_GLM4
+        // || model.arch == LLM_ARCH_GLM4_MOE
+        || model.arch == LLM_ARCH_MIMO2;
+        // || (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8);
 
     if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
         model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
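Note: this hunk hoists the Vulkan #ifdef and the per-architecture f16-overflow list, previously duplicated at each flash-attention call site, into a single should_use_f32_precision flag computed once. It is not a pure refactor: QWEN2 and MIMO2 are added to the list, while GLM4_MOE and the DEEPSEEK2 small-batch case are left commented out. Both call sites then consume the flag the same way; a condensed sketch using only names that appear in this diff (the 0.0f arguments stand in for the real ALiBi-bias and softcap parameters):

    // Force f32 accumulation on the flash-attention node for architectures
    // known to produce NaNs/gibberish with f16 precision on CUDA.
    cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, 0.0f, 0.0f);
    if (should_use_f32_precision) {
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
    }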
@@ -9367,35 +9389,28 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                     ggml_row_size(split_vl->type, n_embd_head_v), 0);
             cb(v, "v", il_cb);
 
-#ifdef GGML_USE_VULKAN
-            constexpr bool use_f32_precision = true;
-#else
-            constexpr bool use_f32_precision = false;
-#endif
-            if (/*k->ne[1] >= 8192 &&*/ q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
+            if (q->ne[1] == 1 && k->ne[1] >= 65536/k->ne[2] && q->ne[2] / k->ne[2] == 12 && !sinks && n_swa == 0 &&
                 k->view_src && k->view_src->buffer && !ggml_backend_buffer_is_host(k->view_src->buffer)) {
-                cur = build_glm45_fa(ctx0, q, k, v, KQ_mask, KQ_scale, use_f32_precision);
+                cur = build_glm45_fa(ctx0, q, k, v, KQ_mask, KQ_scale, should_use_f32_precision);
             } else {
                 cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
                         hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
                 cb(cur, "flash_attn", il_cb);
                 if (model.layers[il].attn_sinks && model.layers[il].attn_sinks->extra) {
                     auto split = (ggml_split_tensor_t *)model.layers[il].attn_sinks->extra;
                     GGML_ASSERT(split->n_device == wq->n_device);
                     GGML_ASSERT(split->splits[id]);
                     ggml_flash_attn_ext_add_sinks(cur, split->splits[id]);
                 } else {
                     ggml_flash_attn_ext_add_sinks(cur, sinks);
                 }
                 if (n_swa > 0) {
                     ((int32_t *)cur->op_params)[4] = n_swa;
                 }
                 // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
-                if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
-                    (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
-                    model.arch == LLM_ARCH_GLM4_MOE) {
+                if (should_use_f32_precision) {
                     ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
                 }
             }
 
             cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
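Note: unlike the fixed 8192 threshold in llm_build_kqv, this split-tensor path scales the threshold with the per-device KV-head count, so the decision tracks the total amount of K data attended rather than the sequence length alone. Worked out under the assumption that k->ne[2] is the number of KV heads resident on this device:

    // k->ne[1] >= 65536/k->ne[2]   <=>   n_kv * n_head_kv >= 65536
    // e.g. n_head_kv = 8 (all KV heads on one device): kicks in at n_kv >= 8192,
    //      matching the fixed threshold in llm_build_kqv;
    // e.g. n_head_kv = 4 (KV heads split across two devices): n_kv >= 16384.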