Fix bug and D < 128 case for Q8_0 k-cache (#52)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2024-09-13 07:19:47 +03:00
committed by GitHub
parent e25c2e7ec2
commit 2bafb03aac

View File

@@ -7492,7 +7492,11 @@ struct FlashQKfp32 {
#ifdef __aarch64__
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
#else
mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
if constexpr (D >= 128) {
mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
} else {
mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
}
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
@@ -7523,7 +7527,7 @@ struct FlashQKfp32 {
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
GGML_ASSERT(nq < 8);
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
#ifdef __aarch64__
case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
@@ -7545,9 +7549,9 @@ struct FlashQKfp32 {
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
switch (nq) {
case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
@@ -7555,16 +7559,30 @@ struct FlashQKfp32 {
case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
#else
case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
#endif
}
#else
if constexpr (D >= 128) {
switch (nq) {
case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
}
} else {
switch (nq) {
case 1: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
}
}
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};