diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 20a9831b..f7c83bed 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1971,7 +1971,7 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f sum_x += w * x[i]; } if (min > 0) min = 0; - if (max == min) { + if (max - min < 1e-10f) { for (int i = 0; i < n; ++i) L[i] = 0; *the_min = -min; return 0.f; @@ -2218,7 +2218,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f if (min > 0) { min = 0; } - if (max <= min) { + if (max - min < 1e-10f) { memset(L, 0, n); *the_min = -min; return 0.f; @@ -2340,7 +2340,7 @@ static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * for (int i = 0; i < n; ++i) { max = MAX(max, x[i]); } - if (!max) { // all zero + if (max < 1e-16f) { // all zero for (int i = 0; i < n; ++i) { L[i] = 0; } return 0.f; } @@ -2733,6 +2733,10 @@ void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, in float av_x = sqrtf(sum_x2/32); for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); + if (isnan(scales[j])) { + printf("Oops: NaN scale\n"); + GGML_ABORT("Fatal error"); + } float scale = scales[j]; if (scale > max_scale) { max_scale = scale; @@ -2846,10 +2850,18 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri for (int l = 0; l < 32; ++l) sumw += weights[l]; sw[j] = sumw; scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + if (isnan(scales[j])) { + printf("%s: got NaN scale\n", __func__); + GGML_ABORT("Fatal error"); + } } float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); + if (isnan(d_block) || isnan(m_block)) { + printf("%s: d_block = %g, m_block = %g\n", __func__, (double)d_block, (double)m_block); + GGML_ABORT("Fatal error"); + } for (int j = 0; j < QK_K/32; ++j) { uint8_t ls = Ls[j]; uint8_t lm = Lm[j];