mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 07:20:12 +00:00
iqk_mul_mat: Arm implementation for iq2_xs (llama.cpp version)
We get 2.2X for PP-512 (52 t/s)
This commit is contained in:
@@ -2141,7 +2141,79 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
|
||||
float d;
|
||||
};
|
||||
|
||||
inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {
|
||||
//constexpr static const uint64x2_t scale_shuffle = { 0x0b030a0209010800, 0x0f070e060d050c04 };
|
||||
//auto aux1 = vld1_u8(sc);
|
||||
//auto aux2 = vshr_n_u8(aux1, 4);
|
||||
//auto scales8 = vqtbl1q_u8(vandq_u8(vcombine_u8(aux1, aux2), vdupq_n_u8(0xf)), vreinterpretq_u8_u64(scale_shuffle));
|
||||
//scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(scales8, 1), vdupq_n_u8(1)));
|
||||
//
|
||||
auto aux = vld1_u8(sc);
|
||||
auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
|
||||
auto scales_h = vshr_n_u8(aux, 4);
|
||||
auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
|
||||
|
||||
//auto scales_l = vld1_u8(sc);
|
||||
//auto scales_h = vshr_n_u8(scales_l, 4);
|
||||
//auto aux1 = vandq_u8(vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)), vdupq_n_u8(0xf));
|
||||
|
||||
auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
|
||||
int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
|
||||
DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
|
||||
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
|
||||
return prepare_4bit_scales16(x[i].scales);
|
||||
}
|
||||
|
||||
inline static uint8x16_t make1(const uint16_t * qs) {
|
||||
auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511))));
|
||||
auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));
|
||||
return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s));
|
||||
}
|
||||
|
||||
inline static void make4(const uint16_t * qs, uint8x16_t * b) {
|
||||
b[0] = make1(qs + 0);
|
||||
b[1] = make1(qs + 2);
|
||||
b[2] = make1(qs + 4);
|
||||
b[3] = make1(qs + 6);
|
||||
//auto bits = vld1q_u16(qs);
|
||||
//auto vidx = vandq_u16(bits, vdupq_n_u16(511));
|
||||
//const uint16_t * idx = (const uint16_t *)&vidx;
|
||||
//b[0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[0])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[1])));
|
||||
//b[1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[2])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[3])));
|
||||
//b[2] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[4])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[5])));
|
||||
//b[3] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[6])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[7])));
|
||||
//vidx = vshrq_n_u16(bits, 9);
|
||||
//b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[0])), vld1_s8((const int8_t *)(keven_signs + idx[1])))));
|
||||
//b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[2])), vld1_s8((const int8_t *)(keven_signs + idx[3])))));
|
||||
//b[2] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[4])), vld1_s8((const int8_t *)(keven_signs + idx[5])))));
|
||||
//b[3] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[3]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[6])), vld1_s8((const int8_t *)(keven_signs + idx[7])))));
|
||||
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
make4(x[i].qs + 16*j + 0, bits.b1.val);
|
||||
make4(x[i].qs + 16*j + 8, bits.b2.val);
|
||||
}
|
||||
|
||||
SimpleBits bits;
|
||||
|
||||
float d;
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
@@ -2152,13 +2224,14 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
|
||||
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
|
||||
return prepare_4bit_scales16(x[i].scales);
|
||||
|
||||
auto aux1 = vld1_u8(x[i].scales);
|
||||
auto aux2 = vshr_n_u8(aux1, 4);
|
||||
auto scales8 = vqtbl1q_u8(vandq_u8(vcombine_u8(aux1, aux2), vdupq_n_u8(0xf)), vreinterpretq_u8_u64(scale_shuffle));
|
||||
scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(scales8, 1), vdupq_n_u8(1)));
|
||||
int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
|
||||
return make_wider(scales16);
|
||||
//auto aux1 = vld1_u8(x[i].scales);
|
||||
//auto aux2 = vshr_n_u8(aux1, 4);
|
||||
//auto scales8 = vqtbl1q_u8(vandq_u8(vcombine_u8(aux1, aux2), vdupq_n_u8(0xf)), vreinterpretq_u8_u64(scale_shuffle));
|
||||
//scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(scales8, 1), vdupq_n_u8(1)));
|
||||
//int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
|
||||
//return make_wider(scales16);
|
||||
}
|
||||
|
||||
static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh,
|
||||
@@ -2756,6 +2829,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
MulMat::set_functions<DequantizerIQ2XXS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
MulMat::set_functions<DequantizerIQ2XS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
MulMat::set_functions<DequantizerIQ2S>(m);
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user