New iq4_kt: AVX2 dot product finally works

We get 13.6 t/s, up from 8.4 t/s with the previous f16 trellis and f32 arithmetic.
Still somewhat slower than other quants, but no longer pathetic.
This commit is contained in:
Iwan Kawrakow
2025-06-07 19:12:09 +03:00
parent 36fba1fff2
commit 78411343cc

View File

@@ -402,92 +402,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
}
}
/*
template <int nrc_y>
void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
constexpr int kNumGroups = 64;
Trellis3 trellis;
constexpr int k_acc = nrc_y;
__m256 accd[k_acc];
const block_q8_2_x4 * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
}
__m256i xv[8];
const block_iq4_kt * x8[8];
float dkt[8];
int32_t ls[8];
uint32_t idx0[8], idx[8];
union { float f; uint32_t u; } bf16_helper;
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) {
const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
dkt[k] = dptr[0];
x8[k] = (const block_iq4_kt *)(dptr + 2);
}
auto vd = _mm256_loadu_ps(dkt);
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
for (int ib = 0; ib < QK_K/32; ++ib) {
for (int k = 0; k < 8; ++k) {
ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
}
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls)));
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-126.f));
int shift1 = 8 - 4*(ib/4);
for (int j = 0; j < 8; ++j) {
for (int k = 0; k < 8; ++k) {
const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
const uint8_t * qh = ql + kNumGroups;
const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);
idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
}
xv[j] = trellis.next32<true>(idx);
}
for (int iy = 0; iy < nrc_y; ++iy) {
const auto& yb = y[iy][2*i+ib/4];
int i4 = ib%4;
auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0);
auto vy = MM256_SET_M128I(vy8, vy8);
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50));
sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff));
vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1);
vy = MM256_SET_M128I(vy8, vy8);
sumi = _mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50));
sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff));
bf16_helper.u = yb.d[i4] << 16;
auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f));
accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]);
bf16_helper.u = yb.d[i4+4] << 16;
accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, accd[iy]);
}
}
}
*/
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
const int nb = n/QK_K;
@@ -573,11 +487,12 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
auto compute_dot = [&dot, &xv] (const int8_t * y) {
for (int k = 0; k < 4; ++k) {
auto yv = _mm256_loadu_si256((const __m256i *)y + k);
dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
//dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
}
};
auto m126 = _mm256_set1_ps(-126.f);
//auto m126 = _mm256_set1_ps(-126.f);
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
@@ -609,30 +524,18 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
}
}
// sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8)
// d4*d8*sum[x_i*y_i] - 126*d4*m8
for (int i128 = 0; i128 < 2; ++i128) {
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
//auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16));
//auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
//auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
//m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f));
//for (int k = 0; k < 4; ++k) {
// xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
// auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k);
// dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
//}
//accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]);
//accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]);
//for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& yb = y[iy][2*i+i128];
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
dy = _mm256_mul_ps(scales[i128], dy);
auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
//auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
compute_dot(yb.qs);
accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
//accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
}
}
}