iqk_mul_mat: add q8_0

The q8_0 implementation was actually already there, just not turned on.
Having forgotten that, I wrote a new implementation along the
lines of the fp16 kernel (i.e., using tiling). That matched
tinyBLAS performance, but the pre-existing implementation that
is now enabled is faster:
PP-512 = 134 t/s vs 128.3 t/s for tinyBLAS
TG-128 = 8.7 t/s vs 8.3 t/s for tinyBLAS (@ 4 threads)
Author: Iwan Kawrakow
Date:   2024-06-07 17:43:29 +03:00
parent 29164263f4
commit 74b711c8fd
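
Both the tiled experiment and the template path that is now enabled come down
to the same inner step: a signed int8 x signed int8 dot product built from
_mm256_maddubs_epi16 (or _mm256_dpbusd_epi32 where VNNI is available), which
take an unsigned first multiplicand. Below is a minimal standalone sketch of
that sign trick, assuming plain AVX2; the helper name and test data are made
up for illustration, and it mirrors the mul_q80 helper in the diff rather
than the exact committed kernels.

// Sketch only; build with e.g. -mavx2.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Dot product of 32 signed int8 values.
static int32_t dot_q80_sketch(const int8_t * x, const int8_t * y) {
    __m256i vx  = _mm256_loadu_si256((const __m256i *)x);
    __m256i vy  = _mm256_loadu_si256((const __m256i *)y);
    __m256i ux  = _mm256_sign_epi8(vx, vx);                      // |x|: usable as the unsigned operand
    __m256i sy  = _mm256_sign_epi8(vy, vx);                      // move x's sign onto y
    __m256i p16 = _mm256_maddubs_epi16(ux, sy);                  // 16 x int16 pairwise products
    __m256i p32 = _mm256_madd_epi16(p16, _mm256_set1_epi16(1));  // widen to 8 x int32
    int32_t tmp[8];
    _mm256_storeu_si256((__m256i *)tmp, p32);
    int32_t sum = 0;
    for (int i = 0; i < 8; ++i) sum += tmp[i];
    return sum;
}

int main() {
    int8_t x[32], y[32];
    int32_t ref = 0;
    for (int i = 0; i < 32; ++i) { x[i] = (int8_t)(i - 16); y[i] = (int8_t)(3*i - 40); ref += x[i]*y[i]; }
    printf("simd = %d, scalar = %d\n", dot_q80_sketch(x, y), ref);
    return 0;
}

The committed kernels apply this per 32-byte block and then scale the int32
result by the product of the two fp16 block scales.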


@@ -723,125 +723,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
};
//struct SimpleBitsAVX512 {
// __m512i values[4];
//};
//
//struct SignHelperAVX512 {
// inline void sign_2_values(const uint16_t * sign_bits, __m512i * values) const {
// const __mmask64 * mask = (const __mmask64 *)sign_bits;
// values[0] = _mm512_mask_sub_epi8(values[0], mask[0], _mm512_setzero_si512(), values[0]);
// values[1] = _mm512_mask_sub_epi8(values[1], mask[1], _mm512_setzero_si512(), values[1]);
// //auto minus = _mm512_set1_epi8(-1);
// //auto neg_value = _mm512_sub_epi8(_mm512_xor_si512(values[0], minus), minus);
// //values[0] = _mm512_mask_blend_epi8(mask[0], values[0], neg_value);
// //neg_value = _mm512_sub_epi8(_mm512_xor_si512(values[1], minus), minus);
// //values[1] = _mm512_mask_blend_epi8(mask[1], values[1], neg_value);
// }
//};
//
//struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
// DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
//
// constexpr static int num_blocks = 8;
//
// inline __m128i make_scales(int i, float& dd) const {
// dd = GGML_FP16_TO_FP32(x[i].d);
// uint32_t aux32[2];
// std::memcpy(aux32, x[i].scales, 4);
// aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
// aux32[0] &= 0x0f0f0f0f;
// auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400));
// auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
// return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
// }
// template <typename Q8>
// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
// prepare(i);
// auto scales16 = make_scales(i, d);
// scb.accum_mins(scales16, q8, i, -minv*d, accd);
// auto scales256 = MM256_SET_M128I(scales16, scales16);
// auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
// scales[0] = _mm512_shuffle_epi8(all_scales, shuffles512[0]);
// scales[1] = _mm512_shuffle_epi8(all_scales, shuffles512[1]);
// }
//
// union index_t {
// __m512i vec;
// uint32_t val[16];
// };
//
// inline static __m512i make1(const uint8_t * qs, const uint8_t * qh, const __m512i& idx_shift, const __m512i& idx_mask) {
// auto idx_l = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)qs));
// auto idx_h = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_set1_epi32(qh[0])), _mm256_set1_epi32(qh[1]), 1);
// idx_h = _mm512_and_si512(_mm512_sllv_epi32(idx_h, idx_shift), idx_mask);
// index_t idx; idx.vec = _mm512_or_si512(idx_l, idx_h);
// return _mm512_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
// iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]],
// iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
// iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
// ////index_t idx1, idx2;
// ////auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
// ////auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
// ////idx1.vec = _mm256_or_si256(idx_h, idx_l);
// ////idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
// ////idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
// ////idx2.vec = _mm256_or_si256(idx_h, idx_l);
// ////return _mm512_set_epi32(iq3s_grid[idx2.val[7]], iq3s_grid[idx2.val[6]], iq3s_grid[idx2.val[5]], iq3s_grid[idx2.val[4]],
// //// iq3s_grid[idx2.val[3]], iq3s_grid[idx2.val[2]], iq3s_grid[idx2.val[1]], iq3s_grid[idx2.val[0]],
// //// iq3s_grid[idx1.val[7]], iq3s_grid[idx1.val[6]], iq3s_grid[idx1.val[5]], iq3s_grid[idx1.val[4]],
// //// iq3s_grid[idx1.val[3]], iq3s_grid[idx1.val[2]], iq3s_grid[idx1.val[1]], iq3s_grid[idx1.val[0]]);
// //////return _mm512_inserti32x8(value, val, 1);
// //index_t idx;
// //auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
// //auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
// //idx.vec = _mm256_or_si256(idx_h, idx_l);
// //auto val = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
// // iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
// //auto value = _mm512_inserti32x8(_mm512_setzero_si512(), val, 0);
// //idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
// //idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
// //idx.vec = _mm256_or_si256(idx_h, idx_l);
// //val = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
// // iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
// //return _mm512_inserti32x8(value, val, 1);
// }
//
// inline void prepare(int i) {
// prepare_unsigned(i);
// auto signs = (const uint16_t *)x[i].signs;
// sh.sign_2_values(signs+0, bits.values+0);
// sh.sign_2_values(signs+8, bits.values+2);
// auto min_value = _mm512_set1_epi8(minv);
// for (int k = 0; k < 4; ++k) bits.values[k] = _mm512_add_epi8(bits.values[k], min_value);
// }
//
// inline void prepare_unsigned(int i) {
// auto qs = x[i].qs;
// auto qh = x[i].qh;
// bits.values[0] = make1(qs+ 0, qh+0, idx_shift, idx_mask);
// bits.values[1] = make1(qs+16, qh+2, idx_shift, idx_mask);
// bits.values[2] = make1(qs+32, qh+4, idx_shift, idx_mask);
// bits.values[3] = make1(qs+48, qh+6, idx_shift, idx_mask);
// }
//
// constexpr static int minv = 16;
//
// SimpleBitsAVX512 bits;
// SignHelperAVX512 sh;
// Scales8KBase scb;
// const __m512i idx_shift = _mm512_set_epi32(1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8);
// const __m512i idx_mask = _mm512_set1_epi32(256);
// //const __m256i min_value = _mm256_set1_epi8(minv);
// const __m512i shuffles512[2] = {
// _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
// 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
// _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
// 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
// };
//
//};
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -2154,7 +2035,7 @@ struct Q_Unpacker {
struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
- inline static int block_size() { return QK4_0; }
+ inline static int block_size() { return QK8_0; }
};
struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
@@ -2173,22 +2054,6 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
inline static int block_size() { return QK4_1; }
};
template <int nrc_y>
void mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%Q8_0_Unpacker::block_size() == 0);
Q8<nrc_y, block_q8_0> q8(info);
int nb = n/Q8_0_Unpacker::block_size();
if (nb%4 == 0) {
mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(
nb, vx, bx, info, q8.y, nrc_x
);
} else {
mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(
nb, vx, bx, info, q8.y, nrc_x
);
}
}
template <int nrc> struct QF32 {
constexpr static int nrc_y = nrc;
QF32(const DataInfo& info) {
@@ -2332,8 +2197,75 @@ void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info,
}
//#endif
template <int nrc> struct Q80 {
constexpr static int nrc_y = nrc;
Q80(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
}
IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); }
IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); }
const block_q8_0 * y[nrc_y];
};
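// mul_q80: maddubs/dpbusd need an unsigned first multiplicand, so use |x| and move x's sign onto y before multiplying.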
inline __m256i mul_q80(__m256i x, __m256i y) {
auto ux = _mm256_sign_epi8(x, x);
#ifdef HAVE_FANCY_SIMD
return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x));
#else
return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x)));
#endif
}
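// Tiled q8_0 x q8_0 kernel: each pass combines k_nx = 4 rows of x with nrc_y rows of y, keeping the k_nx*nrc_y running dot products in registers.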
template <int nrc_y>
void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK8_0 == 0);
constexpr int k_nx = 4;
int nb = n/QK8_0;
Q80<nrc_y> q8(info);
const block_q8_0 * x[k_nx];
float ds[k_nx];
__m256 acc[k_nx*nrc_y];
__m256i xv[k_nx];
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
int ix0 = k_nx*ix;
for (int kx = 0; kx < k_nx; ++kx) {
x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx);
ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d);
xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto yv = q8.load1(iy, 0);
float d = q8.scale(iy, 0);
for (int kx = 0; kx < k_nx; ++kx) {
auto dot = mul_q80(yv, xv[kx]);
acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot));
}
}
for (int i = 1; i < nb; ++i) {
for (int kx = 0; kx < k_nx; ++kx) {
ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d);
xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto yv = q8.load1(iy, i);
float d = q8.scale(iy, i);
for (int kx = 0; kx < k_nx; ++kx) {
auto dot = mul_q80(yv, xv[kx]);
acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
}
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
// TODO: handle remaining rows
}
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
- if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker>) {
+ if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
+               std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
@@ -2353,27 +2285,6 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
}
// else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S>) {
//#ifdef HAVE_FANCY_SIMD
// m.funcs[0] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 1>;
// m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
// m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
// m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
// m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
// m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
// m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
// m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
//#else
// m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
// m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
// m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
// m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;
// m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;
// m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;
// m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;
// m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;
//#endif
// }
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
@@ -2440,6 +2351,19 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
row_size_q8 = ggml_row_size(GGML_TYPE_F32, ne00);
return true;
}
// Using the standard legacy quant template is slightly faster than tiling
// as implemented in mul_mat_q80_q80_T
// if (typeA == GGML_TYPE_Q8_0) {
// for (auto& f : mm.funcs) f = nullptr;
// mm.funcs[0] = mul_mat_q80_q80_T<1>;
// mm.funcs[1] = mul_mat_q80_q80_T<2>;
// mm.funcs[2] = mul_mat_q80_q80_T<3>;
//#ifdef __AVX512F__
// mm.funcs[3] = mul_mat_q80_q80_T<4>;
//#endif
// row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
// return true;
// }
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
@@ -2508,6 +2432,11 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
MulMat::set_functions<Q5_1_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);
break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
MulMat::set_functions<Q8_0_Unpacker>(mm);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
break;
default:
return false;