diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index c639c8cd..7555dc14 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -13187,13 +13187,15 @@ struct FlashQKV { } } } - F16::Data v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); + F16::Data v1, v2, v3, v4; + for (int l1 = 0; l1 < k_step; l1 += 2) { + vh.load(l1+0, i, v1, v2); + vh.load(l1+1, i, v3, v4); for (int j = 0; j < q_step; ++j) { - auto vs = F16::set1(fms.cache[k_step*j + l1]); - vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs); - vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs); + auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]); + auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]); + vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2); + vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2); } } for (int j = 0; j < q_step; ++j) { @@ -13221,13 +13223,15 @@ struct FlashQKV { } } } - F16::Data v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); + F16::Data v1, v2, v3, v4; + for (int l1 = 0; l1 < k_step; l1 += 2) { + vh.load(l1+0, i, v1, v2); + vh.load(l1+1, i, v3, v4); for (int j = 0; j < nq1; ++j) { - auto vs = F16::set1(fms.cache[k_step*j + l1]); - vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs); - vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs); + auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]); + auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]); + vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2); + vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2); } } for (int j = 0; j < nq1; ++j) { @@ -14149,14 +14153,25 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str } #ifdef __AVX512BF16__ -template +template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, float scale, float softcap, float * qkv) { HelperBF16 kh(k, stride_k); HelperBF16 vh(v, stride_v); - if (nq1 >= q_step) { - FlashAttnBF16 fa(scale, softcap); + if (nk1 >= 4096) { + if (nq1 >= 64) { + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + else if (nq1 >= 16) { + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return; + } + if (nq1 >= 8) { + FlashAttnBF16 fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { FlashAttnBF16 fa(scale, softcap); @@ -14176,10 +14191,12 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperF16 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; +#ifdef HAVE_FANCY_SIMD case GGML_TYPE_BF16: { HelperBF16 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; +#endif case GGML_TYPE_Q8_0: { HelperQ80 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -14215,12 +14232,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperF16 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; -#ifdef HAVE_FANCY_SIMD case GGML_TYPE_Q8_0: { HelperQ80 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; -#endif case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -14286,13 +14301,13 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types switch (D) { case 64: - iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; }