46.0 t/s -> 162.0 t/s. iq3_s_r4 is at 79.4 t/s
This commit is contained in:
Iwan Kawrakow
2025-06-23 15:34:41 +02:00
parent 26965677e8
commit 548a5f3f0d
2 changed files with 103 additions and 1 deletions

View File

@@ -3555,6 +3555,107 @@ void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
}
}
//struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
// DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
//
// constexpr static int num_blocks() { return 8; }
// constexpr static bool should_scale_quants() { return false; }
//
// template <typename Q8>
// inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
// d = GGML_FP16_TO_FP32(x[i].d);
// uint32_t scales32[2];
// std::memcpy(scales32, x[i].scales, 4);
// scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
// scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
// auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7
// scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));
// auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));
// int32x4x2_t scales;
// scales.val[0] = vmovl_s16(vget_low_s16(scales16));
// scales.val[1] = vmovl_s16(vget_high_s16(scales16));
// return scales;
// }
//
// static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
// const int8x16_t& hshift, uint8x16_t * b) {
// auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
// const uint16_t * idx = (const uint16_t *)&vindex;
// b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
// b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
// sh.apply_signs_1(b+0, signs16);
// sh.apply_signs_1(b+1, signs16);
// }
// static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
// const int8x16_t& hshift, uint8x16_t * b) {
// auto idx_l = vld1q_u8(qs);
// make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
// make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
// }
//
// inline void prepare(int i, int j) {
//
// static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
// const auto hshift = vld1q_s16(k_shift);
//
// const auto * qs = x[i].qs + 32*j;
// const auto * qh = x[i].qh + 4*j;
// const auto signs16 = vld1q_u8(x[i].signs + 16*j);
//
// sh.init();
// make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
// make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
// }
//
// SimpleBits bits;
// SignHelper sh;
// uint32x4x2_t gas;
//
//};
void iqk_convert_iq3_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_iq3_s * x8[8];
block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
int8_t ls[16];
SignHelper sh;
uint32_t block[8];
int8x16x2_t xv[8];
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
const auto hshift = vld1q_s16(k_shift);
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_s *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
for (int j = 0; j < 2; ++j) {
const auto * qs = x8[k][i].qs + 32*j;
const auto * qh = x8[k][i].qh + 4*j;
const auto signs16 = vld1q_u8(x8[k][i].signs + 16*j);
sh.init();
DequantizerIQ3S::make4(sh, signs16, qs+ 0, qh+0, hshift, (uint8x16_t *)&xv[4*j+0]);
DequantizerIQ3S::make4(sh, signs16, qs+16, qh+2, hshift, (uint8x16_t *)&xv[4*j+2]);
}
for (int ib32 = 0; ib32 < 8; ++ib32) {
ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*((x8[k][i].scales[ib32/2] >> 4*(ib32%2)) & 0xf) + 1);
}
float dnew = convert_to_q8_k_r8(1.f/127, xv, ls, block, (uint32_t *)y[i].qs + k);
y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
}
}
y += nb;
}
}
}
@@ -3565,7 +3666,7 @@ bool iqk_convert_iquants_q80_r8([[maybe_unused]] int type, int n, [[maybe_unused
case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ2_S : iqk_convert_iq2_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;

View File

@@ -275,6 +275,7 @@ struct MulMat {
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q4_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q4_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;