diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a532d18e..1245f4a3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2143,12 +2143,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const const auto m1_16 = _mm256_set1_epi16(1); #endif - //const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); - const block_iq1_bn * x; const char * cx0 = (const char *)vx; float scale; - //template > float scale; for (int ix = 0; ix < nrc_x; ++ix) { @@ -5894,7 +5891,7 @@ struct DequantizerIQ1BN { } }; -template +template static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; @@ -5904,11 +5901,17 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn int32x4_t accd[nrc_y]; int8x16x4_t v1, v2; - const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); + float scale; for (int ix = 0; ix < nrc_x; ++ix) { - x = (const block_iq1_bn *)((const char *)vx + ix*bx); + const char * cx = ((const char *)vx + ix*bx); + if constexpr (is_iq1_tn) { + scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx); + cx += sizeof(ggml_half); + } + + const block_iq1_bn * x = (const block_iq1_bn *)cx; if constexpr (nrc_y == 1) { int32x4_t acc[4] = {}; @@ -5962,7 +5965,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + if constexpr (is_iq1_tn) { + info.store(ix, iy, scale * vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + } else { + info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])))); + } } } @@ -6159,14 +6166,25 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); break; case GGML_TYPE_IQ1_BN: - m.funcs[0] = mul_mat_iq1bn_q8_K64<1>; - m.funcs[1] = mul_mat_iq1bn_q8_K64<2>; - m.funcs[2] = mul_mat_iq1bn_q8_K64<3>; - m.funcs[3] = mul_mat_iq1bn_q8_K64<4>; - m.funcs[4] = mul_mat_iq1bn_q8_K64<5>; - m.funcs[5] = mul_mat_iq1bn_q8_K64<6>; - m.funcs[6] = mul_mat_iq1bn_q8_K64<7>; - m.funcs[7] = mul_mat_iq1bn_q8_K64<8>; + m.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>; + m.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>; + m.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>; + m.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>; + m.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>; + m.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>; + m.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>; + m.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>; + expected_Btype = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_IQ1_TN: + m.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>; + m.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>; + m.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>; + m.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>; + m.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>; + m.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>; + m.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>; + m.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>; expected_Btype = GGML_TYPE_Q8_K64; break; case GGML_TYPE_IQ2_BN: