diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e6295b9d..275858bc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -2372,12 +2372,15 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri const int nb = k / QK_K; const bool requantize = true; - uint8_t L[QK_K]; + uint8_t L[QK_K], L1[QK_K]; uint8_t Laux[16]; float mins[QK_K/16]; + float mins1[QK_K/16]; float scales[QK_K/16]; + float scales1[QK_K/16]; float sw[QK_K/16]; float weight[16]; + float xaux[16]; uint8_t Ls[QK_K/16], Lm[QK_K/16]; for (int i = 0; i < nb; i++) { @@ -2390,11 +2393,33 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l]; scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + for (int l = 0; l < 16; ++l) xaux[l] = -x[16*j + l]; + scales1[j] = make_qkx3_quants(16, 3, xaux, weight, L1 + 16*j, &mins1[j], Laux, -0.9f, 0.05f, 36, false); + } + + float mse = 0, mse1 = 0; + for (int j = 0; j < QK_K/16; ++j) { + const float * restrict qw = quant_weights + QK_K * i + 16*j; + for (int l = 0; l < 16; ++l) { + float w = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); + float diff = scales[j]*L[16*j+l] - mins[j] - x[16*j + l]; + mse += w*diff*diff; + diff = -scales1[j]*L1[16*j+l] + mins1[j] - x[16*j + l]; + mse1 += w*diff*diff; + } } float dm, mm; - dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); - mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); + if (mse <= mse1) { + dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); + } else { + dm = -make_qp_quants(QK_K/16, 15, scales1, Ls, sw); + mm = -make_qp_quants(QK_K/16, 15, mins1, Lm, sw); + if (!requantize) { + memcpy(L, L1, QK_K); + } + } y[i].d = GGML_FP32_TO_FP16(dm); y[i].dmin = GGML_FP32_TO_FP16(mm);