diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a18c8690..8352a3c0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3045,7 +3045,6 @@ template struct AccumT { } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, accm.result(acc[iy], iy)); - //s[iy*bs] = accm.result(acc[iy], iy); } } }; @@ -3212,6 +3211,35 @@ struct Q_Unpacker { } }; +struct Q8_0_x4_Unpacker { /* Unpacker over rows of block_q8_0_x4 data. set_block_4(i) fills qx[0..3] from x[i] and returns the 64 bits at x[i].d widened from half to single precision (4 floats); set_block(i) fills only qx[0] from one block_q8_0 and returns GGML_FP16_TO_FP32 of its d. */ + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK8_0; } + Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} /* vx is kept as a byte pointer (cx_0) so set_row can apply a byte offset; x initially points at the start of vx */ + + const char * cx_0; + const block_q8_0_x4 * x; + size_t bx; /* row stride in bytes: set_row adds ix*bx to the char pointer cx_0 */ + + __m256i qx[4]; /* written by set_block_4 (all four entries) or set_block (entry 0 only); exposed read-only through quants() */ + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { /* _mm_loadl_epi64 reads 64 bits from x[i].d; _mm_cvtph_ps converts those four half-precision values to four floats */ + auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); + for (int j = 0; j < 4; ++j) { /* unaligned load of the j-th 32-byte chunk of x[i].qs */ + qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + } + return scales; + } + inline auto set_block(int i) { /* NOTE(review): x is a block_q8_0_x4 pointer, so (x + i) steps by i whole x4 blocks before the cast to block_q8_0; if callers pass i counted in single block_q8_0 units this would address the wrong block - confirm against the caller */ + auto q8 = (const block_q8_0 *)(x + i); + qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); + return GGML_FP16_TO_FP32(q8->d); + } +}; + struct Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -7320,6 +7348,60 @@ struct FlashMS { } } + float smax = F16::reduce_max(vk); + if (smax == -INFINITY) { + std::memset(cache + k_step*j, 0, k_step*sizeof(float)); + need_scaling[j] = M[j] == -INFINITY ? 
2 : 0; + return; + } + need_scaling[j] = 0; + if (smax > M[j]) { + if (M[j] > -INFINITY) { + float m = expf(M[j] - smax); + vms[j] = F16::set1(m); + need_scaling[j] = 1; + S[j] *= m; + } else { + need_scaling[j] = 2; + S[j] = 0; + } + M[j] = smax; + } + auto vm = F16::set1(M[j]); + for (int l = 0; l < k_step/F16::block_size; ++l) { + vk[l] = v_expf(F16::sub(vk[l], vm)); + F16::store(cache + k_step*j + F16::block_size*l, vk[l]); + } + S[j] += F16::reduce_add(vk); + } + inline void update_M_S(int j, F16::Data * vk, const char * mask) { + auto vzero = _mm256_set1_epi16(0); + auto vinf = _mm512_set1_ps(-INFINITY); + //for (int l = 0; l < k_step/F16::block_size; ++l) { + // auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); + // vk[l] = _mm512_mask_blend_ps(m16, vinf, F16::load(cache + k_step*j + F16::block_size*l)); + //} + //if (softcap <= 0.0f) { + // for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]); + //} else { + // auto v_softcap = F16::set1(softcap); + // for (int l = 0; l < k_step/F16::block_size; ++l) { + // vk[l] = F16::mul(v_softcap, v_tanh(F16::mul(vscale, vk[l]))); + // } + //} + if (softcap <= 0) { + for (int l = 0; l < k_step/F16::block_size; ++l) { + auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); + vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, F16::load(cache + k_step*j + F16::block_size*l)); + } + } else { + auto v_softcap = F16::set1(softcap); + for (int l = 0; l < k_step/F16::block_size; ++l) { + auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mask + l), vzero); + vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, v_tanh(F16::mul(vscale, F16::load(cache + k_step*j + F16::block_size*l)))); + } + } + float smax = F16::reduce_max(vk); if (smax == -INFINITY) { std::memset(cache + k_step*j, 0, k_step*sizeof(float)); @@ -7636,15 +7718,27 @@ struct FlashQKfp32 { static_assert(q_step <= 8); if constexpr (std::is_same_v>) { 
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); +#endif } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); +#endif } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); +#endif } else { GGML_ASSERT(false); @@ -7657,7 +7751,7 @@ struct FlashQKfp32 { #else F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < q_step; ++j) { - fms.update_M_S(j, vk); + fms.update_M_S(j, vk, mask + stride_m*j); } #endif } @@ -7668,6 +7762,7 @@ struct FlashQKfp32 { if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { +#ifdef __aarch64__ case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; @@ -7675,11 +7770,21 @@ struct FlashQKfp32 { case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 4: 
mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; +#endif } } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { +#ifdef __aarch64__ case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; @@ -7687,11 +7792,21 @@ struct FlashQKfp32 { case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; +#endif } } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { +#ifdef __aarch64__ case 1: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; case 3: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; @@ -7699,6 +7814,15 @@ struct FlashQKfp32 { case 5: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); 
break; case 6: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; +#endif } } else { @@ -7712,7 +7836,7 @@ struct FlashQKfp32 { #else F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < nq; ++j) { - fms.update_M_S(j, vk); + fms.update_M_S(j, vk, mask + stride_m*j); } #endif }