mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-06 20:10:08 +00:00
Much faster iq2_xxs GEMM
PP-512 = 290 t/s vs ~110 t/s (iq2_xxs) or 148 t/s (iq2_xxs_r4) on main.
This commit is contained in:
@@ -1067,7 +1067,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_iq2_xxs,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_xxs_ref,
|
||||
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_2_X4, //GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
.row_meta_size = 0,
|
||||
},
|
||||
|
||||
@@ -172,7 +172,6 @@ static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
|
||||
aux32[0] = a0 & 0x3f3f3f3f;
|
||||
}
|
||||
|
||||
#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
|
||||
const uint64_t keven_signs[128] = {
|
||||
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
|
||||
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
|
||||
@@ -207,7 +206,6 @@ const uint64_t keven_signs[128] = {
|
||||
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef __AVX2__
|
||||
|
||||
@@ -525,6 +523,24 @@ struct Q4Bits {
|
||||
|
||||
#endif
|
||||
|
||||
// In-place transpose of an 8x8 float matrix held in eight AVX registers:
// on entry m[r] holds row r; on exit m[c] holds original column c.
inline void iqk_transpose_8x8(__m256 * m) {
    // Stage 1: 4x4 transposes inside each 128-bit lane, performed
    // independently for the row groups 0..3 and 4..7.
    for (int r = 0; r < 8; r += 4) {
        auto lo01 = _mm256_unpacklo_ps(m[r+0], m[r+1]);
        auto hi01 = _mm256_unpackhi_ps(m[r+0], m[r+1]);
        auto lo23 = _mm256_unpacklo_ps(m[r+2], m[r+3]);
        auto hi23 = _mm256_unpackhi_ps(m[r+2], m[r+3]);
        // Re-interleave 64-bit pairs to finish the in-lane transpose.
        m[r+0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(lo01), _mm256_castps_pd(lo23)));
        m[r+1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(lo01), _mm256_castps_pd(lo23)));
        m[r+2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(hi01), _mm256_castps_pd(hi23)));
        m[r+3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(hi01), _mm256_castps_pd(hi23)));
    }
    // Stage 2: exchange the upper 128-bit half of rows 0..3 with the lower
    // 128-bit half of rows 4..7 to complete the full 8x8 transpose.
    for (int r = 0; r < 4; ++r) {
        auto upper = _mm256_set_m128(_mm256_extractf128_ps(m[r+4], 1), _mm256_extractf128_ps(m[r], 1));
        m[r+0] = _mm256_set_m128(_mm256_castps256_ps128(m[r+4]), _mm256_castps256_ps128(m[r+0]));
        m[r+4] = upper;
    }
}
|
||||
|
||||
#else
|
||||
// ------------------------------------ __aarch64__ --------------------------------------------------
|
||||
|
||||
|
||||
@@ -87,13 +87,12 @@ struct EvenSignHelper {
|
||||
const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
|
||||
const __m256i mask = _mm256_set1_epi32(127);
|
||||
const __m256i mone = _mm256_set1_epi32(1);
|
||||
#else
|
||||
#endif
|
||||
inline void sign_value(uint32_t aux32, __m256i& value) const {
|
||||
auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
|
||||
keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
|
||||
value = _mm256_sign_epi8(value, signs);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
struct SignHelper {
|
||||
@@ -1560,6 +1559,55 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
}
|
||||
}
|
||||
|
||||
// Convert 8 consecutive rows of iq2_xxs data into the interleaved
// block_q8_0_r8 layout used by the repacked q8_0 GEMM kernels.
//   n      - row length in elements (must be a multiple of QK_K)
//   vx/bx  - source rows / stride in bytes between rows
//   vy     - destination, nrc_x/8 * (n/32) blocks of block_q8_0_r8
//   nrc_x  - number of rows to convert (must be a multiple of 8)
void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;  // super-blocks per row

    const block_iq2_xxs * x8[8];  // pointers to the 8 source rows being interleaved

    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;

    ggml_half dh[8];      // per-row super-block scales d
    uint16_t all_ls[64];  // per-32-block integer scales, indexed [8*ib32 + row]
    EvenSignHelper esh;

    uint32_t block[8];    // one dequantized 32-element block as 32 int8 values
    uint32_t aux32[2];    // aux32[0]: 4 grid indices, aux32[1]: signs (28 bits) + scale (4 bits)
    const uint8_t * aux8 = (const uint8_t *)aux32;

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            // TODO: simdify
            for (int k = 0; k < 8; ++k) {
                dh[k] = x8[k][i].d;
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
                    // top 4 bits of aux32[1] encode the block scale as 2*s+1
                    all_ls[8*ib32 + k] = (2*(aux32[1] >> 28) + 1);
                    // look up the four 8-value grid rows for this 32-block
                    auto value = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
                    // apply the packed even-parity signs from aux32[1]
                    esh.sign_value(aux32[1], value);
                    _mm256_storeu_si256((__m256i *)block, value);
                    // scatter into the r8 layout: 4-byte groups from each of the
                    // 8 rows are interleaved; first 16 elements go to qs[0..31],
                    // second 16 to qs[32..63] (as uint32 units).
                    auto qs = (uint32_t *)y[ib32].qs;
                    for (int l = 0; l < 4; ++l) {
                        qs[8*l + k +  0] = block[l + 0];
                        qs[8*l + k + 32] = block[l + 4];
                    }
                }
            }
            // Combine the fp16 super-block scales (times the iq2_xxs 1/8 factor)
            // with the per-block integer scales and store as fp16, 8 rows at a time.
            auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
            for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
                auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
                auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
                auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
            }
            y += QK_K/32;  // advance past the QK_K/32 output blocks just written
        }
    }
}
|
||||
|
||||
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
|
||||
@@ -1629,6 +1677,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
|
||||
|
||||
}
|
||||
|
||||
bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
|
||||
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
|
||||
switch (ggml_type(type)) {
|
||||
case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
// --------------------------------------- __aarch64__ ---------------------------------------------
|
||||
|
||||
|
||||
@@ -8,4 +8,6 @@
|
||||
|
||||
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1615,6 +1615,84 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
#endif
|
||||
|
||||
// NOTE: the commented-out mul_mat_q8_0_r8_q8_K below is unfinished work in
// progress — it still contains placeholder expressions (e.g. `q8.y[iy][].qs+32*k`
// with an empty index, and references to undeclared `scales`/`d8`/`k`).
// Kept for reference only; do not uncomment as-is.
//template <int nrc_y>
//static void mul_mat_q8_0_r8_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
// GGML_ASSERT(nrc_x%8 == 0);
|
||||
// Q8<nrc_y, block_q8_K> q8(info);
|
||||
// auto m1 = _mm256_set1_epi16(1);
|
||||
// int nb = n / QK_K;
|
||||
// __m256 acc[nrc_y] = {};
|
||||
// float d8[4*nrc_y];
|
||||
// __m256i qx[4], sx[4];
|
||||
// auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
|
||||
// auto y128 = _mm_loadu_si128((const __m128i*)qy);
|
||||
// auto y = MM256_SET_M128I(y128, y128);
|
||||
// auto sumi1 = _mm256_add_epi32(
|
||||
// _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
|
||||
// _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
|
||||
// );
|
||||
// auto sumi2 = _mm256_add_epi32(
|
||||
// _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
|
||||
// _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
|
||||
// );
|
||||
// return _mm256_add_epi32(sumi1, sumi2);
|
||||
// };
|
||||
// for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
// const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
|
||||
// for (int i = 0; i < nb; ++i) {
|
||||
// for (int ib = 0; ib < 4; ++ib) {
|
||||
// auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[8*i+ib+0].d));
|
||||
// auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[8*i+ib+4].d));
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// qx[j] = _mm256_loadu_si256((const __m256i *)iq8[8*i+ib+0].qs+j);
|
||||
// sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
|
||||
// }
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// auto sumi = dot(q8.y[iy][].qs+32*k);
|
||||
// auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
|
||||
// acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
// }
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
|
||||
// sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
|
||||
// }
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
|
||||
// auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
|
||||
// acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// for (int ib = 4*(nb/4); ib < nb; ++ib) {
|
||||
// auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
|
||||
// sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
|
||||
// }
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// auto qy = (const block_q8_2 *)q8.y[iy];
|
||||
// auto sumi = dot(qy[ib].qs);
|
||||
// auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
|
||||
// acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
// }
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
|
||||
// sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
|
||||
// }
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// auto qy = (const block_q8_2 *)q8.y[iy];
|
||||
// auto sumi = dot(qy[ib].qs+16);
|
||||
// auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
|
||||
// acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
// }
|
||||
// }
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// info.store(ix, iy, acc[iy]);
|
||||
// acc[iy] = _mm256_setzero_ps();
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
|
||||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
|
||||
|
||||
@@ -239,6 +239,7 @@ struct MulMat {
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
default: break;
|
||||
}
|
||||
#else
|
||||
@@ -327,6 +328,89 @@ static std::vector<char> & thread_local_work_buffer() {
|
||||
return f;
|
||||
}
|
||||
|
||||
// Dispatch: convert/repack nrc_x rows of type `typeA` into the layout expected
// by the corresponding GEMM kernels — q8_0_r8 for the i-quants, dequantized
// floats (stride_y apart) for the *_KT trellis types.
// Returns false when no conversion is implemented for `typeA`.
// The commented-out cases mark types whose conversions may be added later.
bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, size_t stride_y, int nrc_x) {

    switch (typeA) {
        //case GGML_TYPE_F16:
        //case GGML_TYPE_F32:
        //case GGML_TYPE_BF16:
        //case GGML_TYPE_BF16_R16:
        //    return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
        //case GGML_TYPE_Q2_K:
        //case GGML_TYPE_Q3_K:
        //case GGML_TYPE_Q4_K:
        //case GGML_TYPE_Q5_K:
        //case GGML_TYPE_Q6_K:
        //case GGML_TYPE_IQ4_XS:
        //case GGML_TYPE_Q2_K_R4:
        //case GGML_TYPE_Q3_K_R4:
        //case GGML_TYPE_Q4_K_R4:
        //case GGML_TYPE_Q5_K_R4:
        //case GGML_TYPE_Q6_K_R4:
        //case GGML_TYPE_IQ4_XS_R8:
        //case GGML_TYPE_Q8_K_R8:
        //case GGML_TYPE_Q8_KV:
        //case GGML_TYPE_Q8_KV_R8:
        //    return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16);
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ2_XXS_R4:
        case GGML_TYPE_IQ2_XS_R4:
        case GGML_TYPE_IQ2_S_R4:
        case GGML_TYPE_IQ3_XXS_R4:
        case GGML_TYPE_IQ3_S_R4:
            // NOTE(review): iqk_convert_iquants_q80_r8 currently only performs
            // the conversion for GGML_TYPE_IQ2_XXS and returns false for the
            // other types routed here — confirm this fan-in is intentional.
            return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x);
        //case GGML_TYPE_IQ4_KS:
        //case GGML_TYPE_IQ5_KS:
        //case GGML_TYPE_IQ4_KSS:
        //case GGML_TYPE_IQ2_K:
        //case GGML_TYPE_IQ2_KS:
        //case GGML_TYPE_IQ3_K:
        //case GGML_TYPE_IQ4_K:
        //case GGML_TYPE_IQ5_K:
        //case GGML_TYPE_IQ6_K:
        //case GGML_TYPE_IQ2_K_R4:
        //case GGML_TYPE_IQ3_K_R4:
        //case GGML_TYPE_IQ4_K_R4:
        //case GGML_TYPE_IQ5_K_R4:
        //case GGML_TYPE_IQ4_KS_R4:
        //case GGML_TYPE_IQ5_KS_R4:
        //    return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
        case GGML_TYPE_IQ2_KT:
        case GGML_TYPE_IQ3_KT:
        case GGML_TYPE_IQ4_KT:
            // Trellis types are dequantized to float rows stride_y apart.
            return iqk_dequantize_ktquants(typeA, n, vx, bx, vy, stride_y, nrc_x);
        //case GGML_TYPE_Q4_0:
        //case GGML_TYPE_Q4_1:
        //case GGML_TYPE_Q5_0:
        //case GGML_TYPE_Q5_1:
        //case GGML_TYPE_Q6_0:
        //case GGML_TYPE_Q8_0:
        //case GGML_TYPE_IQ4_NL:
        //case GGML_TYPE_Q4_0_R8:
        //case GGML_TYPE_Q5_0_R4:
        //case GGML_TYPE_Q6_0_R4:
        //case GGML_TYPE_Q8_0_R8:
        //case GGML_TYPE_IQ4_NL_R4:
        //    return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
        //case GGML_TYPE_IQ1_S:
        //case GGML_TYPE_IQ1_S_R4:
        //case GGML_TYPE_IQ1_M_R4:
        //case GGML_TYPE_IQ1_BN:
        //case GGML_TYPE_IQ2_BN:
        //case GGML_TYPE_IQ2_BN_R4:
        //    return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);

        default:
            return false;
    }

    // Unreachable: every active case above returns; kept as a safety net.
    return false;
}
|
||||
|
||||
}
|
||||
|
||||
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
|
||||
@@ -352,9 +436,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
|
||||
first_x *= num_rows;
|
||||
nrc_x *= num_rows;
|
||||
|
||||
auto type_size = ggml_type_size(dequant_type);
|
||||
|
||||
size_t row_size_qx = ne00*type_size;
|
||||
size_t row_size_qx = ggml_row_size(dequant_type, ne00);
|
||||
size_t row_size_qy = strideB;
|
||||
|
||||
//printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
|
||||
@@ -368,7 +450,7 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
|
||||
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
|
||||
if (!iqk_convert_repack(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
|
||||
|
||||
Reference in New Issue
Block a user