mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-05 11:30:09 +00:00
iq3_kt is now working on NEON
This commit is contained in:
@@ -1664,6 +1664,9 @@ struct Trellis3 {
|
||||
result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1));
|
||||
result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2));
|
||||
}
|
||||
if constexpr (is_abs) {
|
||||
for (int i = 0; i < 4; ++i) result.val[i] = vabsq_s8(result.val[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
static uint8x16_t load_shuffle() {
|
||||
@@ -1879,6 +1882,69 @@ void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
|
||||
}
|
||||
}
|
||||
|
||||
void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Trellis3<true> trellis;
|
||||
|
||||
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
|
||||
|
||||
const block_iq3_kt * x8[8];
|
||||
|
||||
float dkt[8];
|
||||
float ls[8], ls_all[64];
|
||||
uint32_t idx[8];
|
||||
uint32_t sign_bits[16];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
|
||||
dkt[k] = dptr[0] * 1.05f;
|
||||
x8[k] = (const block_iq3_kt *)(dptr + 1);
|
||||
}
|
||||
auto vd = vld1q_f32_x2(dkt);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
auto u32 = *(const uint32_t *)x8[k][i].scales;
|
||||
auto s8_u32 = uint32x2_t{u32, u32 >> 4};
|
||||
s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
|
||||
auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
|
||||
vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
|
||||
vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
|
||||
}
|
||||
auto mask = vdupq_n_u8(1);
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
|
||||
auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
|
||||
auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
|
||||
vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
|
||||
vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
|
||||
idx[k] = ql[4*ib+j] + 4096;
|
||||
auto qh = (const uint32_t *)x8[k][i].qh;
|
||||
sign_bits[k+0] = qh[2*j+0];
|
||||
sign_bits[k+8] = qh[2*j+1];
|
||||
}
|
||||
auto packed = trellis.next64(idx);
|
||||
auto signs = vld1q_u8_x4((const uint8_t *)sign_bits);
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
auto s = vorrq_u8(vceqq_u8(vandq_u8(signs.val[l], mask), mask), vdupq_n_u8(1));
|
||||
packed.val[l] = vmulq_s8(packed.val[l], vreinterpretq_s8_u8(s));
|
||||
}
|
||||
vst1q_s8_x4(y[ib].qs+64*j, packed);
|
||||
}
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
}
|
||||
y += 8; // = QK_K/32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -2158,10 +2224,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
return true;
|
||||
}
|
||||
|
||||
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
|
||||
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, [[maybe_unused]] size_t stride_y, int nrc_x) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt_q80_r8(n, vx, bx, y, nrc_x); break;
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
|
||||
default: return false;
|
||||
}
|
||||
|
||||
@@ -272,7 +272,7 @@ struct MulMat {
|
||||
#else
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
default: break;
|
||||
}
|
||||
@@ -435,7 +435,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
|
||||
return iqk_convert_1bit_q80_r8(typeA, n, vx, bx, vy, nrc_x);
|
||||
|
||||
default:
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user