mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq4_kt: slightly faster TG on NEON
This commit is contained in:
@@ -397,6 +397,10 @@ struct Trellis1 {
|
||||
auto x16 = gen8(val1, val2);
|
||||
return { vcvt_f32_f16(vget_low_f16(x16)), vcvt_f32_f16(vget_high_f16(x16)) };
|
||||
}
|
||||
inline float32x4x2_t gen8_f32(uint32_t val1, uint32_t val2, float16x8_t scale) const {
|
||||
auto x16 = vmulq_f16(gen8(val1, val2), scale);
|
||||
return { vcvt_f32_f16(vget_low_f16(x16)), vcvt_f32_f16(vget_high_f16(x16)) };
|
||||
}
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
@@ -640,10 +644,10 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
const float x_scale1 = (int)((shb[ib+0] & 0xff) >> 1) - 64;
|
||||
const float x_scale2 = (int)((shb[ib+4] & 0xff) >> 1) - 64;
|
||||
const float32x4_t scale1 = vdupq_n_f32(x_scale1);
|
||||
const float32x4_t scale2 = vdupq_n_f32(x_scale2);
|
||||
const uint16_t x_scale1 = (int16_t)((shb[ib+0] & 0xff) >> 1) - 64;
|
||||
const uint16_t x_scale2 = (int16_t)((shb[ib+4] & 0xff) >> 1) - 64;
|
||||
const float16x8_t scale1 = vcvtq_f16_s16(vdupq_n_s16(x_scale1));
|
||||
const float16x8_t scale2 = vcvtq_f16_s16(vdupq_n_s16(x_scale2));
|
||||
const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15);
|
||||
const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15);
|
||||
|
||||
@@ -660,12 +664,8 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
sh1 >>= 6;
|
||||
sh2 >>= 6;
|
||||
|
||||
auto x1 = trellis.gen8_f32(val1, val3);
|
||||
auto x2 = trellis.gen8_f32(val2, val4);
|
||||
x1.val[0] = vmulq_f32(scale1, x1.val[0]);
|
||||
x1.val[1] = vmulq_f32(scale1, x1.val[1]);
|
||||
x2.val[0] = vmulq_f32(scale2, x2.val[0]);
|
||||
x2.val[1] = vmulq_f32(scale2, x2.val[1]);
|
||||
auto x1 = trellis.gen8_f32(val1, val3, scale1);
|
||||
auto x2 = trellis.gen8_f32(val2, val4, scale2);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*j);
|
||||
|
||||
Reference in New Issue
Block a user