iq1_tn: NEON

This commit is contained in:
Iwan Kawrakow
2024-09-08 17:40:08 +02:00
parent c82bf200ce
commit 8d509a7d71

View File

@@ -2143,12 +2143,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
const auto m1_16 = _mm256_set1_epi16(1);
#endif
//const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
const block_iq1_bn * x;
const char * cx0 = (const char *)vx;
float scale;
//template <bool iq1_tn = is_iq1_tn, class = std::enable_if<iq1_tn>> float scale;
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -5894,7 +5891,7 @@ struct DequantizerIQ1BN {
}
};
template <int nrc_y>
template <int nrc_y, bool is_iq1_tn>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
@@ -5904,11 +5901,17 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
int32x4_t accd[nrc_y];
int8x16x4_t v1, v2;
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
float scale;
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
const char * cx = ((const char *)vx + ix*bx);
if constexpr (is_iq1_tn) {
scale = GGML_FP16_TO_FP32(*(const ggml_half *)cx);
cx += sizeof(ggml_half);
}
const block_iq1_bn * x = (const block_iq1_bn *)cx;
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
@@ -5962,7 +5965,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
if constexpr (is_iq1_tn) {
info.store(ix, iy, scale * vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
} else {
info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
@@ -6159,14 +6166,25 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
MulMat::set_functions<DequantizerIQ3S>(m);
break;
case GGML_TYPE_IQ1_BN:
m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
m.funcs[0] = mul_mat_iq1bn_q8_K64<1, false>;
m.funcs[1] = mul_mat_iq1bn_q8_K64<2, false>;
m.funcs[2] = mul_mat_iq1bn_q8_K64<3, false>;
m.funcs[3] = mul_mat_iq1bn_q8_K64<4, false>;
m.funcs[4] = mul_mat_iq1bn_q8_K64<5, false>;
m.funcs[5] = mul_mat_iq1bn_q8_K64<6, false>;
m.funcs[6] = mul_mat_iq1bn_q8_K64<7, false>;
m.funcs[7] = mul_mat_iq1bn_q8_K64<8, false>;
expected_Btype = GGML_TYPE_Q8_K64;
break;
// IQ1_TN dispatch: uses the same kernel template as the IQ1_BN case, but
// instantiated with is_iq1_tn = true. In the kernel shown in this file, that
// flag makes the kernel read a ggml_half at the start of each row, convert it
// to a float scale, step past it, and multiply every stored result by it.
case GGML_TYPE_IQ1_TN:
// funcs[k] holds the instantiation for nrc_y = k + 1 (1 through 8).
m.funcs[0] = mul_mat_iq1bn_q8_K64<1, true>;
m.funcs[1] = mul_mat_iq1bn_q8_K64<2, true>;
m.funcs[2] = mul_mat_iq1bn_q8_K64<3, true>;
m.funcs[3] = mul_mat_iq1bn_q8_K64<4, true>;
m.funcs[4] = mul_mat_iq1bn_q8_K64<5, true>;
m.funcs[5] = mul_mat_iq1bn_q8_K64<6, true>;
m.funcs[6] = mul_mat_iq1bn_q8_K64<7, true>;
m.funcs[7] = mul_mat_iq1bn_q8_K64<8, true>;
// Same value as the IQ1_BN case sets. NOTE(review): presumably the type the
// second operand must be quantized to; confirm where expected_Btype is checked.
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN: