mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Something is not working with the AVX2 dot product
This commit is contained in:
@@ -120,11 +120,13 @@ struct Trellis3 {
|
||||
auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8);
|
||||
return _mm256_cvtepi32_ps(i8);
|
||||
}
|
||||
template <bool is_unsigned = false>
|
||||
inline __m256i next32(const uint32_t * val) const {
|
||||
const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126);
|
||||
__m256i aux[4];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f));
|
||||
aux[i] = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), i8);
|
||||
aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8);
|
||||
}
|
||||
aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
||||
aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
||||
@@ -352,20 +354,6 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
}
|
||||
|
||||
// Q8_0 repacking:
|
||||
// for (int ib = 0; ib < nblock; ++ib) {
|
||||
// for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
|
||||
// for (int l = 0; l < 4; ++l) {
|
||||
// for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
|
||||
// y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
|
||||
// y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
|
||||
// as uint32_t
|
||||
// y[ib].qs[8*l+k+ 0] = x8[k][ib].qs[l+ 0];
|
||||
// y[ib].qs[8*l+k+32] = x8[k][ib].qs[l+16];
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
@@ -397,46 +385,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
|
||||
}
|
||||
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls)));
|
||||
_mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
|
||||
//for (int k = 0; k < 8; ++k) {
|
||||
// auto shb = x8[k][i].qs;
|
||||
// const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
// const uint8_t * qh = ql + kNumGroups;
|
||||
// for (int ib = 0; ib < 4; ++ib) {
|
||||
// uint32_t offset1 = ((shb[ib+0] & 1) << 15) + 4096;
|
||||
// uint32_t offset2 = ((shb[ib+4] & 1) << 15) + 4096;
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
// const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
// idx[64*ib + 16*j + k ] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
// idx[64*ib + 16*j + k + 8] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
// idx[64*ib + 16*j + k + 256] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
// idx[64*ib + 16*j + k + 264] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
// //uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
// //uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
// //uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
// //uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
// //auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav);
|
||||
// //auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav);
|
||||
// //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1);
|
||||
// //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
//for (int j = 0; j < 64; ++j) {
|
||||
// _mm256_storeu_si256((__m256i *)y[j/8].qs+(j%8), trellis.next32(idx+8*j));
|
||||
//}
|
||||
//int shift1 = 8 - 4*(ib/4);
|
||||
//for (int j = 0; j < 4; ++j) {
|
||||
// for (int k = 0; k < 8; ++k) {
|
||||
// const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
|
||||
// const uint8_t * qh = ql + kNumGroups;
|
||||
// const uint32_t sh = x8[k][i].qs[ib] >> (8 + 6*j);
|
||||
// idx[k+0] = ql[8*ib+2*j+0] + ((qh[8*(ib%4)+2*j+0] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
|
||||
// idx[k+8] = ql[8*ib+2*j+1] + ((qh[8*(ib%4)+2*j+1] << shift1) & 0xf00) + ((sh & 56) << 9) + idx0[k];
|
||||
// }
|
||||
// _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, trellis.next32(idx+0));
|
||||
// _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, trellis.next32(idx+8));
|
||||
//}
|
||||
int shift1 = 8 - 4*(ib/4);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
@@ -454,6 +402,92 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
Trellis3 trellis;
|
||||
|
||||
constexpr int k_acc = nrc_y;
|
||||
|
||||
__m256 accd[k_acc];
|
||||
const block_q8_2_x4 * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
|
||||
}
|
||||
|
||||
__m256i xv[8];
|
||||
|
||||
const block_iq4_kt * x8[8];
|
||||
float dkt[8];
|
||||
int32_t ls[8];
|
||||
uint32_t idx0[8], idx[8];
|
||||
|
||||
union { float f; uint32_t u; } bf16_helper;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
|
||||
dkt[k] = dptr[0];
|
||||
x8[k] = (const block_iq4_kt *)(dptr + 2);
|
||||
}
|
||||
auto vd = _mm256_loadu_ps(dkt);
|
||||
|
||||
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
|
||||
idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
|
||||
}
|
||||
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls)));
|
||||
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-126.f));
|
||||
int shift1 = 8 - 4*(ib/4);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);
|
||||
idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
|
||||
}
|
||||
xv[j] = trellis.next32<true>(idx);
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const auto& yb = y[iy][2*i+ib/4];
|
||||
int i4 = ib%4;
|
||||
auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0);
|
||||
auto vy = MM256_SET_M128I(vy8, vy8);
|
||||
auto sumi = _mm256_setzero_si256();
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff));
|
||||
vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1);
|
||||
vy = MM256_SET_M128I(vy8, vy8);
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff));
|
||||
bf16_helper.u = yb.d[i4] << 16;
|
||||
auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f));
|
||||
accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]);
|
||||
bf16_helper.u = yb.d[i4+4] << 16;
|
||||
accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
@@ -503,6 +537,112 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix multiplication of nrc_x rows of IQ4_KT-quantized weights (vx, row stride bx)
// against nrc_y rows of Q8_2_x4-quantized activations, using AVX2 + VNNI
// (_mm256_dpbusd_epi32) integer dot products.
// NOTE(review): the commit this comes from is titled "Something is not working with
// the AVX2 dot product" — treat the numerical output of this kernel as unverified.
template <int nrc_y>
void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QK_K == 0);
    const int nb = n/QK_K;               // number of QK_K-sized super-blocks per row
    constexpr int kNumGroups = 64;       // bytes of low index bits (ql) per super-block

    Trellis3 trellis;                    // generates 32 pseudo-random weight bytes per 16-bit index

    // Lets us build 8 per-sub-block index offsets with one SIMD expression (vec)
    // and then read them back as scalars (val[]).
    union { __m256i vec; uint32_t val[8]; } o_helper;

    constexpr int k_acc = nrc_y;

    __m256 accd[k_acc];                  // one float accumulator vector per activation row
    const block_q8_2_x4 * y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
    }

    uint32_t values[64];                 // decoded trellis indices for one 128-value half-block
    __m256i xv[4], dot[4];               // generated weight bytes / per-block i32 dot products
    __m256 scales[2];                    // block scales for the low/high 128-value halves

    // Horizontally reduces dot[0..3] (8 partial sums each) into one vector holding
    // the 4 block sums twice, converted to float.
    auto sum_4 = [&dot] () {
        // dot[k] has 8 values from block k
        // 0 1 0 1 0 1 0 1
        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
        // 2 3 2 3 2 3 2 3
        dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
        // 0 1 2 3 0 1 2 3
        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
        return _mm256_cvtepi32_ps(dot[0]);
    };

    // Unsigned(xv) x signed(y) byte dot products for 4 x 32 bytes.
    // xv[k] must already hold the (unsigned, un-offset) generated weight bytes.
    auto compute_dot = [&dot, &xv] (const int8_t * y) {
        for (int k = 0; k < 4; ++k) {
            auto yv = _mm256_loadu_si256((const __m256i *)y + k);
            dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
        }
    };

    // The trellis bytes are used unsigned here (next32<true>); the -126 offset that
    // makes them signed weights is folded back in via m8 * m126 below.
    auto m126 = _mm256_set1_ps(-126.f);

    for (int ix = 0; ix < nrc_x; ++ix) {
        // Each weight row starts with two floats; dptr[0] is the row scale,
        // the quantized blocks follow after dptr[1] (dptr[1] unused here).
        const float * dptr = (const float *)((const char*)vx + ix*bx);
        auto d = _mm256_set1_ps(dptr[0]);
        const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);

        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();

        for (int i = 0; i < nb; ++i) {
            auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
            const uint32_t * shb = x[i].qs;                  // 8 per-sub-block header words
            const uint8_t * ql = (const uint8_t *)(shb + 8); // low 8 bits of each trellis index
            const uint8_t * qh = ql + kNumGroups;            // packed high 4 bits (two nibbles/byte)
            // Block scale: low byte of the header, bit 0 stripped, minus 64.
            auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
            iscales = _mm256_sub_epi32(iscales, _mm256_set1_epi32(64));
            auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(iscales));
            // Duplicate the low/high 4 scales across both 128-bit lanes.
            auto scales_l = _mm256_castps256_ps128(all_scales);
            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
            scales[0] = _mm256_set_m128(scales_l, scales_l);
            scales[1] = _mm256_set_m128(scales_h, scales_h);
            // Per-sub-block index base: 4096 + (header bit 0 << 15).
            o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
            // Reassemble the full 16-bit trellis indices: ql (8 bits) + qh nibble (bits 8..11)
            // + 3 header bits (bits 12..14) + the base offset.
            for (int ib = 0; ib < 4; ++ib) {
                for (int j = 0; j < 4; ++j) {
                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
                    values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
                    values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
                    values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
                    values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
                }
            }
            // sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8)
            // d4*d8*sum[x_i*y_i] - 126*d4*m8
            for (int i128 = 0; i128 < 2; ++i128) {
                // Generate the 128 unsigned weight bytes for this half of the super-block.
                for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
                //auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16));
                //auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
                //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
                //m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f));
                //for (int k = 0; k < 4; ++k) {
                //    xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
                //    auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k);
                //    dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
                //}
                //accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]);
                //accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    const block_q8_2_x4& yb = y[iy][2*i+i128];
                    // yb.d holds 8 bf16 values (shift u16 into the high half of an f32):
                    // presumably 4 scales followed by 4 row sums — TODO confirm against Q8_2_x4.
                    auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
                    dy = _mm256_mul_ps(scales[i128], dy);
                    auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
                    auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
                    compute_dot(yb.qs);
                    // d4*d8*sum[x_i*y_i] term, then the -126*d4*m8 correction term.
                    accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
                    accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
                }
            }
        }

        // Horizontal sum of each accumulator gives the final dot product for (ix, iy).
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, hsum_float_8(accd[iy]));
        }
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -585,11 +725,21 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
|
||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) {
|
||||
if (ne00%QK_K != 0) return false;
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
if (typeA == GGML_TYPE_IQ4_KT) {
|
||||
if (typeB == GGML_TYPE_Q8_2_X4) {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
func16 = nullptr;
|
||||
if (ggml_type(typeB) != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
|
||||
@@ -815,7 +815,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
|
||||
return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16);
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
Reference in New Issue
Block a user