Fix q8_0 KV cache when not using FA - NEON

This commit is contained in:
Iwan Kawrakow
2025-01-15 11:36:01 +01:00
parent ad78678bb9
commit 0ecc20e481

View File

@@ -12036,35 +12036,35 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m); MulMat::set_functions<DequantizerQ40>(m);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
MulMat::set_functions<DequantizerQ41>(m); MulMat::set_functions<DequantizerQ41>(m);
expected_Btype = GGML_TYPE_Q8_1; expected_Btype = GGML_TYPE_Q8_1_X4;
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
MulMat::set_functions<DequantizerQ50>(m); MulMat::set_functions<DequantizerQ50>(m);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
MulMat::set_functions<DequantizerQ51>(m); MulMat::set_functions<DequantizerQ51>(m);
expected_Btype = GGML_TYPE_Q8_1; expected_Btype = GGML_TYPE_Q8_1_X4;
break; break;
case GGML_TYPE_Q6_0: case GGML_TYPE_Q6_0:
MulMat::set_functions<DequantizerQ60>(m); MulMat::set_functions<DequantizerQ60>(m);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m); MulMat::set_functions<DequantizerQ80>(m);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
MulMat::set_functions<DequantizerIQ4NL>(m); MulMat::set_functions<DequantizerIQ4NL>(m);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_IQ4_NL_R4: case GGML_TYPE_IQ4_NL_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer); SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ4_XS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k); SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
@@ -12141,19 +12141,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break; break;
case GGML_TYPE_Q4_0_R4: case GGML_TYPE_Q4_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer); SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_Q5_0_R4: case GGML_TYPE_Q5_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer); SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_Q6_0_R4: case GGML_TYPE_Q6_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer); SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
case GGML_TYPE_Q8_0_R4: case GGML_TYPE_Q8_0_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0); SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0);
expected_Btype = GGML_TYPE_Q8_0; expected_Btype = GGML_TYPE_Q8_0_X4;
break; break;
default: default:
return false; return false;
@@ -12461,9 +12461,9 @@ struct HelperQ80 final : public BaseHelper<step> {
int j = F16::block_size*i; int j = F16::block_size*i;
auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0; auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0;
#ifdef __aarch64__ #ifdef __aarch64__
const float16_t * d = (const float16_t *)dl->d; auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto vd = F16::set1(d[ii]); int ii = j%QK8_0;
auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32); auto qs = vld1_s8_x2(dl->qs + ii);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
#else #else
@@ -13210,7 +13210,7 @@ struct FlashQKfp32 {
} }
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
#ifdef __aarch64__ #ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80_x4, nq); MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else #else
if constexpr (D >= 128) { if constexpr (D >= 128) {
#ifdef HAVE_FANCY_SIMD #ifdef HAVE_FANCY_SIMD