mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
iq4_kt: NEON implementation
Have to use f32 arithmetic else I get gibberish? Correspondigly ridiculously slow.
This commit is contained in:
@@ -1618,7 +1618,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
|
||||
.vec_dot = vec_dot_iq4_kt_q8_k,
|
||||
#ifdef __ARM_NEON
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
//.vec_dot_type = GGML_TYPE_F16,
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
#else
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
#endif
|
||||
|
||||
@@ -393,6 +393,10 @@ struct Trellis1 {
|
||||
}
|
||||
inline float16x8_t gen8(uint32_t val) const { return gen8(next8(val)); }
|
||||
inline float16x8_t gen8(uint32_t val1, uint32_t val2) const { return gen8(next8(val1, val2)); }
|
||||
inline float32x4x2_t gen8_f32(uint32_t val1, uint32_t val2) const {
|
||||
auto x16 = gen8(val1, val2);
|
||||
return { vcvt_f32_f16(vget_low_f16(x16)), vcvt_f32_f16(vget_high_f16(x16)) };
|
||||
}
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
@@ -604,10 +608,107 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
float32x4_t accd[nrc_y * 2];
|
||||
const float * y[nrc_y];
|
||||
float row_sum[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
y[iy] = (const float *)info.src1_row(iy);
|
||||
auto sum = vdupq_n_f32(0);
|
||||
for (int i = 0; i < n/4; ++i) sum = vaddq_f32(sum, vld1q_f32(y[iy] + 4*i));
|
||||
row_sum[iy] = vaddvq_f32(sum);
|
||||
}
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
const float d = dptr[0] * 31.75f * 1.01f;
|
||||
const float row_av = dptr[1];
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
for (int iy = 0; iy < nrc_y * 2; ++iy) {
|
||||
accd[iy] = vdupq_n_f32(0.0f);
|
||||
}
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
const float x_scale1 = (int)((shb[ib+0] & 0xff) >> 1) - 64;
|
||||
const float x_scale2 = (int)((shb[ib+4] & 0xff) >> 1) - 64;
|
||||
const float32x4_t scale1 = vdupq_n_f32(x_scale1);
|
||||
const float32x4_t scale2 = vdupq_n_f32(x_scale2);
|
||||
const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15);
|
||||
const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15);
|
||||
|
||||
uint32_t sh1 = shb[ib+0] >> 8;
|
||||
uint32_t sh2 = shb[ib+4] >> 8;
|
||||
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
//int j = 32*ib + 8*jj;
|
||||
// -> (j/8)%4 = (4*ib+jj)%4 = jj%4;
|
||||
// j/4 = 8*ib + 2*jj;
|
||||
//const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
|
||||
//const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
|
||||
|
||||
uint32_t val1 = ql[8*ib+2*jj+ 0] + ((qh[8*ib+2*jj+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
uint32_t val2 = ql[8*ib+2*jj+32] + ((qh[8*ib+2*jj+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
uint32_t val3 = ql[8*ib+2*jj+ 1] + ((qh[8*ib+2*jj+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
uint32_t val4 = ql[8*ib+2*jj+33] + ((qh[8*ib+2*jj+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
|
||||
sh1 >>= 6;
|
||||
sh2 >>= 6;
|
||||
|
||||
auto x1 = trellis.gen8_f32(val1, val3);
|
||||
auto x2 = trellis.gen8_f32(val2, val4);
|
||||
x1.val[0] = vmulq_f32(scale1, x1.val[0]);
|
||||
x1.val[1] = vmulq_f32(scale1, x1.val[1]);
|
||||
x2.val[0] = vmulq_f32(scale2, x2.val[0]);
|
||||
x2.val[1] = vmulq_f32(scale2, x2.val[1]);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*jj);
|
||||
auto y2 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*jj + 128);
|
||||
|
||||
accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y1.val[0], x1.val[0]);
|
||||
accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y1.val[1], x1.val[1]);
|
||||
accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y2.val[0], x2.val[0]);
|
||||
accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y2.val[1], x2.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// Sum the two accumulators for this y row
|
||||
float32x4_t sum1 = vaddq_f32(accd[iy*2], accd[iy*2 + 1]);
|
||||
|
||||
// Compute final result
|
||||
float result = d*vaddvq_f32(sum1) + row_av*row_sum[iy];
|
||||
info.store(ix, iy, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
|
||||
if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
|
||||
func16 = nullptr;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -654,7 +654,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
return ggml_type(typeB) == GGML_TYPE_F16 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, m.funcs, m.func16) : false;
|
||||
return iqk_set_kernels_ktquants(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user