mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
iq2_tn: AVX512
With this tweak we get to PP-512 = 431 t/s.
This commit is contained in:
@@ -695,11 +695,11 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
|
||||
struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
|
||||
DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
|
||||
inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
bits.prepare(x[i].qs);
|
||||
process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm);
|
||||
scales[0] = scales[1] = _mm512_set1_epi16(1);
|
||||
//process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm);
|
||||
//scales[0] = scales[1] = _mm512_set1_epi16(1);
|
||||
}
|
||||
Q2Bits bits;
|
||||
};
|
||||
@@ -997,14 +997,22 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
|
||||
deq.new_block(i, q8, accm, scales);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
//compute_block(iy, i, deq.d, q8, deq.bits.values, scales, accd);
|
||||
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
|
||||
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
|
||||
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
|
||||
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
|
||||
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
|
||||
auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
|
||||
auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
|
||||
_mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
|
||||
deq.bits.values[0], q8.load_quants64(iy, i, 0)), deq.bits.values[1], q8.load_quants64(iy, i, 1)),
|
||||
deq.bits.values[2], q8.load_quants64(iy, i, 2)), deq.bits.values[3], q8.load_quants64(iy, i, 3));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
} else {
|
||||
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
|
||||
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
|
||||
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
|
||||
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
|
||||
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user