mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 22:49:31 +00:00
iqk_mul_mat: improve iq1_bn (bitnet) on AVX2
We now get 207 t/s for PP-512 and 51 t/s for TG-128 using 16 threads.
This commit is contained in:
@@ -3724,6 +3724,22 @@ void quantize_row_q8_K64_reference(const float * restrict x, block_q8_K64 * rest
|
||||
assert(k % 64 == 0);
|
||||
const int64_t nb = k / 64;
|
||||
|
||||
// Check if a row-wise scale works. It almost does, PPL is only ~0.02 higher
|
||||
//float amax = 0;
|
||||
//for (int j = 0; j < k; ++j) {
|
||||
// float ax = fabsf(x[j]);
|
||||
// amax = MAX(ax, amax);
|
||||
//}
|
||||
|
||||
//float d = amax/127;
|
||||
//float id = d ? 1/d : 0.f;
|
||||
|
||||
//for (int i = 0; i < nb; i++) {
|
||||
// for (int j = 0; j < 64; ++j) y[i].qs[j] = nearest_int(id*x[j]);
|
||||
// y[i].d = d;
|
||||
// x += 64;
|
||||
//}
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
float max = 0;
|
||||
|
||||
@@ -1370,24 +1370,24 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
|
||||
iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
|
||||
|
||||
auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
|
||||
auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
|
||||
auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
|
||||
auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);
|
||||
auto v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1),
|
||||
_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
|
||||
auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
|
||||
_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto q8_1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), signs[0]);
|
||||
auto q8_2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), signs[1]);
|
||||
auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
|
||||
auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));
|
||||
auto dot1 = _mm256_sign_epi8(q8_1, v1);
|
||||
auto dot2 = _mm256_sign_epi8(q8_2, v2);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
|
||||
dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
|
||||
auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
|
||||
#else
|
||||
dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
|
||||
dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
|
||||
auto dot = _mm256_add_epi32(_mm256_add_epi32(dot1, dot2));
|
||||
#endif
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(dot1, dot2)), accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user