diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index b2f11b2a..d8fb6a4a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -16847,14 +16847,17 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_1_q8_1 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; -#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); @@ -17743,11 +17746,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ60 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; -#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); @@ -17770,7 +17773,8 @@ inline bool flash_attn_is_supported(ggml_type type) { if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true; #else - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8 + || type == GGML_TYPE_Q4_0) return true; #endif return false; }