diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ce868514..a70310d4 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7492,7 +7492,11 @@ struct FlashQKfp32 { #ifdef __aarch64__ mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); #else - mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); + if constexpr (D >= 128) { + mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); + } else { + mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); + } #endif } else if constexpr (std::is_same_v>) { @@ -7523,7 +7527,7 @@ struct FlashQKfp32 { const block_q8 * q, const char * mask, FlashMS& fms) { GGML_ASSERT(nq < 8); if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { #ifdef __aarch64__ case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; @@ -7545,9 +7549,9 @@ struct FlashQKfp32 { } } else if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ + switch (nq) { case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; @@ -7555,16 +7559,30 @@ struct FlashQKfp32 { case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; -#endif } +#else + if constexpr (D >= 128) { + switch (nq) { + case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + } + } else { + switch (nq) { + case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + } + } +#endif } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};