Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA

This commit is contained in:
Iwan Kawrakow
2025-04-23 14:35:34 +03:00
parent ddcdf25e54
commit b19fd13141

View File

@@ -16847,14 +16847,17 @@ struct FlashQKfp32 {
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
#endif
}
#if GGML_IQK_FA_ALL_QUANTS
else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
#endif
}
#if GGML_IQK_FA_ALL_QUANTS
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
@@ -17698,11 +17701,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ60<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_1: {
HelperQ41<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
@@ -17743,11 +17746,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ60<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_1: {
HelperQ41<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
@@ -17770,7 +17773,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true;
#else
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8) return true;
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8
|| type == GGML_TYPE_Q4_0) return true;
#endif
return false;
}