Fix q8_0 KV cache when not using FA - NEON

This commit is contained in:
Iwan Kawrakow
2025-01-15 11:36:01 +01:00
parent ad78678bb9
commit 0ecc20e481

View File

@@ -12036,35 +12036,35 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q4_1:
MulMat::set_functions<DequantizerQ41>(m);
expected_Btype = GGML_TYPE_Q8_1;
expected_Btype = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q5_0:
MulMat::set_functions<DequantizerQ50>(m);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q5_1:
MulMat::set_functions<DequantizerQ51>(m);
expected_Btype = GGML_TYPE_Q8_1;
expected_Btype = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q6_0:
MulMat::set_functions<DequantizerQ60>(m);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_IQ4_NL:
MulMat::set_functions<DequantizerIQ4NL>(m);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_IQ4_NL_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_IQ4_XS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
@@ -12141,19 +12141,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break;
case GGML_TYPE_Q4_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q5_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q6_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q8_0_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0);
expected_Btype = GGML_TYPE_Q8_0;
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
default:
return false;
@@ -12461,9 +12461,9 @@ struct HelperQ80 final : public BaseHelper<step> {
int j = F16::block_size*i;
auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0;
#ifdef __aarch64__
const float16_t * d = (const float16_t *)dl->d;
auto vd = F16::set1(d[ii]);
auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32);
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
int ii = j%QK8_0;
auto qs = vld1_s8_x2(dl->qs + ii);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
#else
@@ -13210,7 +13210,7 @@ struct FlashQKfp32 {
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80_x4, nq);
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
if constexpr (D >= 128) {
#ifdef HAVE_FANCY_SIMD