diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index 17d2dad3..cc8863ce 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -1690,6 +1690,73 @@ static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
     }
 }
 
+void iqk_convert_q80_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    static_assert(QK4_0 == QK8_0);
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK4_0;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_q8_0 * x8[8];
+
+    uint32_t block[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k] = x8[k][i].d;
+                _mm256_storeu_si256((__m256i *)block, _mm256_loadu_si256((const __m256i *)x8[k][i].qs));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k + 0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
+void iqk_convert_q40_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    static_assert(QK4_0 == QK8_0);
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK4_0;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const block_q4_0 * x8[8];
+
+    uint32_t block[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q4_0 *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k] = x8[k][i].d;
+                auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs);
+                auto val = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf));
+                val = _mm256_add_epi8(val, _mm256_set1_epi8(-8));
+                _mm256_storeu_si256((__m256i *)block, val);
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k + 0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
 template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
     if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> || std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
@@ -1713,6 +1780,15 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MA
     }
 }
 
+bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    switch (type) {
+        case GGML_TYPE_Q4_0: iqk_convert_q40_q80_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q8_0: iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
+}
+
 bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
     if (ne00%QK8_0 != 0) return false;
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.h b/ggml/src/iqk/iqk_gemm_legacy_quants.h
index a472d9bb..179e806a 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.h
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.h
@@ -11,4 +11,6 @@
 
 bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
 
+bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
+
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -359,6 +359,8 @@
         case GGML_TYPE_IQ4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_IQ5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_IQ6_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+        case GGML_TYPE_Q4_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+        case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         default: break;
     }
 #else
@@ -403,19 +405,19 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:
             return iqk_dequantize_ktquants(typeA, n, vx, bx, vy, stride_y, nrc_x);
-        //case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_0:
         //case GGML_TYPE_Q4_1:
         //case GGML_TYPE_Q5_0:
         //case GGML_TYPE_Q5_1:
         //case GGML_TYPE_Q6_0:
-        //case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_0:
         //case GGML_TYPE_IQ4_NL:
         //case GGML_TYPE_Q4_0_R8:
         //case GGML_TYPE_Q5_0_R4:
         //case GGML_TYPE_Q6_0_R4:
         //case GGML_TYPE_Q8_0_R8:
         //case GGML_TYPE_IQ4_NL_R4:
-        //    return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
+            return iqk_convert_legacy_quants_q8_r8(typeA, n, vx, bx, vy, nrc_x);
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ1_M:
         //case GGML_TYPE_IQ1_S_R4: