diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f4a031be..05f93f99 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -14837,11 +14837,12 @@ static void ggml_compute_forward_mul_mat_id(
     const int nth = params->nth;
 
     const enum ggml_type type = src0->type;
+    const enum ggml_type dequant_type = iqk_dequant_type((int)type, src1->ne[2]);
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t    const vec_dot         = type_traits[type].vec_dot;
-    enum ggml_type    const vec_dot_type    = type_traits[type].vec_dot_type;
+    enum ggml_type    const vec_dot_type    = type_traits[dequant_type].vec_dot_type;
     ggml_from_float_t const from_float      = type_traits[vec_dot_type].from_float;
     int64_t           const matmul_num_cols = type_traits[type].ncols;
     ggml_gemv_t       const gemv            = type_traits[type].gemv;
@@ -14935,6 +14936,12 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
+        enum ggml_type this_dequant_type = iqk_dequant_type((int)type, cne1);
+        if (this_dequant_type != dequant_type) {
+            printf("Oops: %s (%d) and %s (%d)\n", ggml_type_name(dequant_type), (int)src1->ne[2], ggml_type_name(this_dequant_type), (int)cne1);
+            GGML_ABORT("Fatal error");
+        }
+
         const char * src0_cur = (const char *) src0->data + cur_a*nb02;
 
         const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
@@ -15105,7 +15112,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
 
     const enum ggml_type type = src0->type;
 
-    enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+    enum ggml_type const dequant_type = iqk_dequant_type((int)type, src1->ne[1]);
+    enum ggml_type const vec_dot_type = type_traits[dequant_type].vec_dot_type;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index dbbb5000..dfa0347e 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -861,127 +861,6 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
     }
 }
 
-struct DequantizerQ6K_AVX2 final : public BaseDequantizer<block_q6_K> {
-    DequantizerQ6K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
-    inline void prepare(int i, int j) {
-        auto lbits1 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+0);
-        auto lbits2 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+1);
-        auto hbits  = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
-        bits.values[0] = _mm256_or_si256(_mm256_and_si256(lbits1, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
-        bits.values[1] = _mm256_or_si256(_mm256_and_si256(lbits2, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
-        bits.values[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), bits.ml), _mm256_and_si256(hbits, mh));
-        bits.values[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), bits.ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
-    }
-    inline void prepare_signed(int i, int j, __m256i * us) {
-        prepare(i, j);
-        for (int k = 0; k < 4; ++k) {
-            bits.values[k] = _mm256_add_epi8(bits.values[k], _mm256_set1_epi8(-32));
-            us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
-        }
-    }
-
-    const __m256i mh = _mm256_set1_epi8(0x10);
-    Q4Bits_AVX2 bits;
-};
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    assert(n % QK_K == 0);
-    const int nb = n / QK_K;
-
-    Q8<nrc_y, block_q8_2_x4> q8(info);
-
-    Dequantizer deq(vx, bx);
-
-    __m256  accd[nrc_y];
-    __m256  scales[2];
-    float   d8[8*nrc_y];
-    __m256i us[4];
-
-    uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15};
-    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
-
-    for (int ix = 0; ix < nrc_x; ++ix) {
-
-        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
-
-        deq.new_row(ix);
-
-        for (int i = 0; i < nb; ++i) {
-
-            deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
-            auto vd = _mm256_set1_ps(deq.d);
-            auto sc16 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)deq.x[i].scales)), shuff);
-            scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
-            scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
-            //scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(deq.x[i].scales+0)))));
-            //scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(deq.x[i].scales+8)))));
-            //auto mins1 = _mm256_mul_ps(scales[0], _mm256_set1_ps(-32.f));
-            //auto mins2 = _mm256_mul_ps(scales[1], _mm256_set1_ps(-32.f));
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
-                auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
-                auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
-                _mm256_storeu_ps(d8 + 8*iy, dy);
-                //auto m4_1 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4))), 16));
-                //auto m4_2 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4))), 16));
-                //auto my1 = _mm256_set_m128(_mm_unpackhi_ps(m4_1, m4_1), _mm_unpacklo_ps(m4_1, m4_1)); // 0,0, 1,1, 2,2, 3,3
-                //auto my2 = _mm256_set_m128(_mm_unpackhi_ps(m4_2, m4_2), _mm_unpacklo_ps(m4_2, m4_2)); // 4,4, 5,5, 6,6, 7,7
-                //accd[iy] = _mm256_fmadd_ps(my1, mins1, accd[iy]);
-                //accd[iy] = _mm256_fmadd_ps(my2, mins2, accd[iy]);
-            }
-
-            for (int j = 0; j < QK_K/128; ++j) {
-
-                deq.prepare_signed(i, j, us);
-
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qs = q8.y[iy][2*i+j].qs;
-#ifdef HAVE_FANCY_SIMD
-                    //  0...31
-                    auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
-                    // 32...63
-                    auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
-                    // 64...95
-                    auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
-                    // 96...128
-                    auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
-                    // 0...3, 32...35, 4....7, 36...39, 16...19, 48...51, 20...23, 52...56 +
-                    // 8..11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63
-                    //  b0 b2 b0 b2 b1 b3 b1 b3
-                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
-                    // same as above + 64, so
-                    //  b4 b6, b4 b6 b5 b7 b5 b7
-                    sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
-                    // b0 b2 b4 b6 b1 b3 b5 b7 +
-                    // b0 b2 b4 b6 b1 b3 b5 b7
-                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
-#else
-                    auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
-                    auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
-                    auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
-                    auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
-                    sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
-                    sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
-                    sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
-                    sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
-#endif
-                    auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
-                    auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
-                    accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
-                }
-
-            }
-
-        }
-
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, hsum_float_8(accd[iy]));
-        }
-
-    }
-}
-
 template <int nrc_y>
 static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
@@ -2197,7 +2076,6 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array
             set_functions(kernels);
             break;
         case GGML_TYPE_Q6_K:
-            //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ6K_AVX2, kernels);
             set_functions(kernels);
             break;
        case GGML_TYPE_IQ4_XS:
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 6cce2478..3c915791 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -245,7 +245,8 @@ struct MulMat {
             case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
             case GGML_TYPE_Q4_K  : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
             case GGML_TYPE_Q5_K  : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
-            case GGML_TYPE_Q6_K  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            //case GGML_TYPE_Q6_K : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q6_K  : return GGML_TYPE_Q8_0_R8;
             default: break;
         }
 #else
@@ -602,7 +603,12 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
 
     MulMat mm;
     auto etypeA = ggml_type(typeA);
-    if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
+    //auto etypeB = ggml_type(typeB);
+    auto dequant_type = MulMat::is_dequant_better(etypeA, Ny);
+    //if (etypeB != GGML_TYPE_F32) {
+    //    if (ith == 0) printf("%s: typeA = %s, typeB = %s, dequant_type = %s\n", __func__, ggml_type_name(etypeA), ggml_type_name(etypeB), ggml_type_name(dequant_type));
+    //}
+    if (dequant_type != etypeA) {
         if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
             return false;
         }
@@ -617,9 +623,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
     first_x *= num_rows;
     nrc_x *= num_rows;
 
-    auto type_size = ggml_type_size(dequant_type);
-
-    size_t row_size_qx = ne00*type_size;
+    size_t row_size_qx = ggml_row_size(dequant_type, ne00);
     size_t row_size_qy = strideB;
 
     DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
@@ -631,7 +635,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
             this_info.s += ix;
             int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
             if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
-            if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
+            if (!iqk_convert_repack(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
                 GGML_ABORT("Fatal error");
             }
             mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
@@ -685,9 +689,7 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
     first_x *= num_rows;
     nrc_x *= num_rows;
 
-    auto type_size = ggml_type_size(dequant_type);
-
-    size_t row_size_qx = ne00*type_size;
+    size_t row_size_qx = ggml_row_size(dequant_type, ne00);
     size_t row_size_qy = strideB;
 
     DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
@@ -701,10 +703,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
             if (f.size() < 2*row_size_qx*this_nrc_x) f.resize(2*row_size_qx*this_nrc_x);
             auto Xu = f.data();
             auto Xg = f.data() + row_size_qx*this_nrc_x;
-            if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
+            if (!iqk_convert_repack(typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
                 GGML_ABORT("Fatal error");
             }
-            if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
+            if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
                 GGML_ABORT("Fatal error");
             }
             mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
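Reviewer-style note, not part of the patch: the switch from `ne00*ggml_type_size(dequant_type)` to `ggml_row_size(dequant_type, ne00)` matters once the dequant target can be a block-quantized type such as `Q8_0_R8`, because `ggml_type_size` returns bytes per block, not per element; for float targets the two expressions agree. A minimal standalone C sketch of the arithmetic (the toy traits mirror GGML's Q8_0 layout of one fp16 scale plus 32 int8 values per 32-element block; the helper and struct names are illustrative, not ggml's actual trait table):

```c
#include <stdint.h>
#include <stdio.h>

/* Toy stand-in for ggml's type traits: bytes per block and elements per block. */
typedef struct { size_t type_size; int64_t blck_size; } toy_traits;

/* Same idea as ggml_row_size(): number of blocks in a row times bytes per block. */
static size_t toy_row_size(toy_traits t, int64_t ne) {
    return (size_t)(ne / t.blck_size) * t.type_size;
}

int main(void) {
    toy_traits q8_0 = { 2 + 32, 32 };  /* fp16 scale + 32 int8 quants per 32-element block */
    int64_t ne00 = 4096;

    /* Old expression: treats type_size as if it were per element, over-counting by blck_size. */
    printf("ne00*type_size      = %zu bytes\n", (size_t)ne00 * q8_0.type_size); /* 139264 */
    /* What the patch computes: bytes actually occupied by one quantized row. */
    printf("row_size(type,ne00) = %zu bytes\n", toy_row_size(q8_0, ne00));      /* 4352 */
    return 0;
}
```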