mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
New iq4_kt: AVX2 dot product finally works
We get 13.6 t/s, versus 8.4 t/s with the f16 trellis and f32 arithmetic. This is still somewhat slower than other quants, but no longer pathetic.
This commit is contained in:
@@ -402,92 +402,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
Trellis3 trellis;
|
||||
|
||||
constexpr int k_acc = nrc_y;
|
||||
|
||||
__m256 accd[k_acc];
|
||||
const block_q8_2_x4 * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
|
||||
}
|
||||
|
||||
__m256i xv[8];
|
||||
|
||||
const block_iq4_kt * x8[8];
|
||||
float dkt[8];
|
||||
int32_t ls[8];
|
||||
uint32_t idx0[8], idx[8];
|
||||
|
||||
union { float f; uint32_t u; } bf16_helper;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
|
||||
dkt[k] = dptr[0];
|
||||
x8[k] = (const block_iq4_kt *)(dptr + 2);
|
||||
}
|
||||
auto vd = _mm256_loadu_ps(dkt);
|
||||
|
||||
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
|
||||
idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
|
||||
}
|
||||
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls)));
|
||||
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-126.f));
|
||||
int shift1 = 8 - 4*(ib/4);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);
|
||||
idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
|
||||
}
|
||||
xv[j] = trellis.next32<true>(idx);
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const auto& yb = y[iy][2*i+ib/4];
|
||||
int i4 = ib%4;
|
||||
auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0);
|
||||
auto vy = MM256_SET_M128I(vy8, vy8);
|
||||
auto sumi = _mm256_setzero_si256();
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff));
|
||||
vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1);
|
||||
vy = MM256_SET_M128I(vy8, vy8);
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff));
|
||||
bf16_helper.u = yb.d[i4] << 16;
|
||||
auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f));
|
||||
accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]);
|
||||
bf16_helper.u = yb.d[i4+4] << 16;
|
||||
accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
@@ -573,11 +487,12 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
auto compute_dot = [&dot, &xv] (const int8_t * y) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto yv = _mm256_loadu_si256((const __m256i *)y + k);
|
||||
dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
|
||||
//dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
|
||||
dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k]));
|
||||
}
|
||||
};
|
||||
|
||||
auto m126 = _mm256_set1_ps(-126.f);
|
||||
//auto m126 = _mm256_set1_ps(-126.f);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
@@ -609,30 +524,18 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
}
|
||||
}
|
||||
// sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8)
|
||||
// d4*d8*sum[x_i*y_i] - 126*d4*m8
|
||||
for (int i128 = 0; i128 < 2; ++i128) {
|
||||
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
|
||||
//auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16));
|
||||
//auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
|
||||
//auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
|
||||
//m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f));
|
||||
//for (int k = 0; k < 4; ++k) {
|
||||
// xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
|
||||
// auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k);
|
||||
// dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
|
||||
//}
|
||||
//accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]);
|
||||
//accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]);
|
||||
//for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
|
||||
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const block_q8_2_x4& yb = y[iy][2*i+i128];
|
||||
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
|
||||
dy = _mm256_mul_ps(scales[i128], dy);
|
||||
auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
|
||||
auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
|
||||
//auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
|
||||
compute_dot(yb.qs);
|
||||
accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
|
||||
//accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user