mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Vulkan needs f32 precision for flash attention
This commit is contained in:
@@ -10130,6 +10130,12 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
0);
|
0);
|
||||||
cb(k, "k", il);
|
cb(k, "k", il);
|
||||||
|
|
||||||
|
#ifdef GGML_USE_VULKAN
|
||||||
|
constexpr bool use_f32_precision = true;
|
||||||
|
#else
|
||||||
|
constexpr bool use_f32_precision = false;
|
||||||
|
#endif
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
if (cparams.flash_attn) {
|
if (cparams.flash_attn) {
|
||||||
@@ -10151,7 +10157,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
|
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
|
||||||
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
|
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
|
||||||
// Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
|
// Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
|
||||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
|
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
|
||||||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
|
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
|
||||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||||
}
|
}
|
||||||
@@ -10176,7 +10182,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
|
|
||||||
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
|
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
|
||||||
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
|
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) {
|
||||||
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||||
@@ -15443,6 +15449,11 @@ struct llm_build_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * build_deepseek2() {
|
struct ggml_cgraph * build_deepseek2() {
|
||||||
|
#ifdef GGML_USE_VULKAN
|
||||||
|
constexpr bool use_f32_attn_precision = true;
|
||||||
|
#else
|
||||||
|
constexpr bool use_f32_attn_precision = false;
|
||||||
|
#endif
|
||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||||
|
|
||||||
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||||
@@ -15672,7 +15683,7 @@ struct llm_build_context {
|
|||||||
q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
|
q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
|
||||||
|
|
||||||
kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
|
kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||||
if (q->ne[1] <= 8) {
|
if (use_f32_attn_precision || q->ne[1] <= 8) {
|
||||||
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
|
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
|
||||||
}
|
}
|
||||||
cb(kqv, "kqv", il);
|
cb(kqv, "kqv", il);
|
||||||
@@ -15714,6 +15725,10 @@ struct llm_build_context {
|
|||||||
kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
|
kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
|
||||||
cb(kqv_compressed, "kqv_compressed", il);
|
cb(kqv_compressed, "kqv_compressed", il);
|
||||||
|
|
||||||
|
if (use_f32_attn_precision) {
|
||||||
|
ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);
|
||||||
|
}
|
||||||
|
|
||||||
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
|
||||||
cb(kqv_compressed, "kqv_compressed_perm", il);
|
cb(kqv_compressed, "kqv_compressed_perm", il);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user