mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-03 13:04:59 +00:00
iq2_tn: slightly faster PP (#43)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -696,6 +696,9 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
|
||||
DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
|
||||
new_block(i);
|
||||
}
|
||||
inline void new_block(int i) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
bits.prepare(x[i].qs);
|
||||
}
|
||||
@@ -1125,6 +1128,64 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
DequantizerIQ2TN deq1(vx, bx), deq2(vx, bx);
|
||||
|
||||
__m512 accd[2*nrc_y];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ix += 2) {
|
||||
|
||||
for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
|
||||
|
||||
deq1.new_row(ix+0);
|
||||
deq2.new_row(ix+1);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
deq1.new_block(i);
|
||||
deq2.new_block(i);
|
||||
float d = 0.5f*(deq1.d + deq2.d); // The scale is supposed to be per per tensor, so we can use the same scale for both rows
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sumi_scales_256 = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
|
||||
auto sumi_scales_512 = _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales_256, 0);
|
||||
auto q8q = q8.load_quants64(iy, i, 0);
|
||||
auto sumi_1 = _mm512_dpbusd_epi32(sumi_scales_512, deq1.bits.values[0], q8q);
|
||||
auto sumi_2 = _mm512_dpbusd_epi32(sumi_scales_512, deq2.bits.values[0], q8q);
|
||||
q8q = q8.load_quants64(iy, i, 1);
|
||||
sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[1], q8q);
|
||||
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[1], q8q);
|
||||
q8q = q8.load_quants64(iy, i, 2);
|
||||
sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[2], q8q);
|
||||
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[2], q8q);
|
||||
q8q = q8.load_quants64(iy, i, 3);
|
||||
sumi_1 = _mm512_dpbusd_epi32(sumi_1, deq1.bits.values[3], q8q);
|
||||
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
|
||||
// The scale is supposed to be per per tensor, so we can use the same scale
|
||||
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
|
||||
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
||||
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
||||
// Leaving this here just in case ternary models start using per row scales
|
||||
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
|
||||
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
|
||||
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -3589,7 +3650,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
assert (ne00 % QK_K == 0);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
MulMat::set_functions<DequantizerIQ2TN>(mm);
|
||||
//MulMat::set_functions<DequantizerIQ2TN>(mm);
|
||||
mm.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<DequantizerIQ2TN>;
|
||||
//mm.funcs[0] = mul_mat_iq2tn_q8_K_AVX512<1>;
|
||||
mm.funcs[1] = mul_mat_iq2tn_q8_K_AVX512<2>;
|
||||
mm.funcs[2] = mul_mat_iq2tn_q8_K_AVX512<3>;
|
||||
mm.funcs[3] = mul_mat_iq2tn_q8_K_AVX512<4>;
|
||||
mm.funcs[4] = mul_mat_iq2tn_q8_K_AVX512<5>;
|
||||
mm.funcs[5] = mul_mat_iq2tn_q8_K_AVX512<6>;
|
||||
mm.funcs[6] = mul_mat_iq2tn_q8_K_AVX512<7>;
|
||||
mm.funcs[7] = mul_mat_iq2tn_q8_K_AVX512<8>;
|
||||
#else
|
||||
mm.funcs[0] = mul_mat_iq2tn_q8_K<1>;
|
||||
mm.funcs[1] = mul_mat_iq2tn_q8_K<2>;
|
||||
|
||||
Reference in New Issue
Block a user