mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-12 23:10:01 +00:00
Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA
This commit is contained in:
@@ -16847,14 +16847,17 @@ struct FlashQKfp32 {
|
||||
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
|
||||
#endif
|
||||
}
|
||||
-#if GGML_IQK_FA_ALL_QUANTS
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
|
||||
#else
|
||||
+if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 1, k_step>, 1);
|
||||
+if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 2, k_step>, 2);
|
||||
+if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q4_0_1_Unpacker, 4, k_step>, 4);
|
||||
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q4_0_1_Unpacker, nq);
|
||||
#endif
|
||||
}
|
||||
+#if GGML_IQK_FA_ALL_QUANTS
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
|
||||
@@ -17698,11 +17701,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
HelperQ60<Dv, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
-#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<Dv, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
+#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<Dv, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
@@ -17743,11 +17746,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
HelperQ60<Dk, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
-#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40<Dk, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
} break;
|
||||
+#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41<Dk, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
@@ -17770,7 +17773,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
|
||||
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
|
||||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true;
|
||||
#else
|
||||
-if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8) return true;
|
||||
+if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8
|
||||
+|| type == GGML_TYPE_Q4_0) return true;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user