mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
iqk_mul_mat(iq1_bn): WIP NEON (not working)
This commit is contained in:
100
iqk_mul_mat.cpp
100
iqk_mul_mat.cpp
@@ -4011,6 +4011,95 @@ void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc> struct Q8_K64 {
|
||||
|
||||
constexpr static int nrc_y = nrc;
|
||||
|
||||
Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); }
|
||||
|
||||
inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy][i].qs); }
|
||||
inline float scale(int iy, int i) const { return y[iy][i].d; }
|
||||
|
||||
const block_q8_K64 * y[nrc_y];
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const int nb = n / QK_IQ1BN;
|
||||
Q8_K64<nrc_y> q8(info);
|
||||
float32x4_t accd[nrc_y];
|
||||
int8x16x4_t signs;
|
||||
|
||||
uint64x2x4_t aux;
|
||||
uint8x16x4_t vp, vm;
|
||||
|
||||
const auto m1 = vdupq_n_u8(1);
|
||||
uint8x16x4_t sign_shuffles;
|
||||
sign_shuffles.val[0] = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101});
|
||||
sign_shuffles.val[1] = vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303});
|
||||
sign_shuffles.val[2] = vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505});
|
||||
sign_shuffles.val[3] = vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707});
|
||||
const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0808080808080808});
|
||||
const auto shuff2 = vaddq_u8(shuff1, m1);
|
||||
const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
|
||||
|
||||
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
|
||||
typedef union { float f; uint32_t i; } scale_t;
|
||||
|
||||
scale_t scale;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
|
||||
uint16_t u = x[0].extra & 0xff;
|
||||
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
auto all_signs = vdupq_n_u8(x[i].extra >> 8);
|
||||
all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
|
||||
signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
|
||||
signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
|
||||
signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
|
||||
signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
|
||||
|
||||
auto ql = x[i].ql;
|
||||
auto qh = x[i].qh;
|
||||
aux.val[0] = uint64x2_t{iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)]};
|
||||
aux.val[1] = uint64x2_t{iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)]};
|
||||
aux.val[2] = uint64x2_t{iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)]};
|
||||
aux.val[3] = uint64x2_t{iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)]};
|
||||
|
||||
vp.val[0] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[0], shuff1), mask1), mask1);
|
||||
vp.val[1] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[1], shuff1), mask1), mask1);
|
||||
vp.val[2] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[2], shuff1), mask1), mask1);
|
||||
vp.val[3] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[3], shuff1), mask1), mask1);
|
||||
vm.val[0] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[0], shuff2), mask1), mask1);
|
||||
vm.val[1] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[1], shuff2), mask1), mask1);
|
||||
vm.val[2] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[2], shuff2), mask1), mask1);
|
||||
vm.val[3] = vceqq_u8(vandq_u8(vqtbl1q_u8(aux.val[3], shuff2), mask1), mask1);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto q = q8.load_quants(iy, i);
|
||||
int32x4_t sumi = vdupq_n_s32(0);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
auto tmp = vmulq_s8(q.val[j], signs.val[j]);
|
||||
tmp = vsubq_s8(vmulq_s8(q.val[j], vm.val[j]), vmulq_s8(q.val[j], vp.val[j]));
|
||||
sumi = ggml_vdotq_s32(sumi, m1, tmp);
|
||||
}
|
||||
accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, scale.f * vaddvq_f32(accd[iy]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ80>) {
|
||||
@@ -4094,6 +4183,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_IQ3_S:
|
||||
MulMat::set_functions<DequantizerIQ3S>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
|
||||
m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
|
||||
m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
|
||||
m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
|
||||
m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
|
||||
m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
|
||||
m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
|
||||
m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
|
||||
expected_Btype = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
MulMat::set_functions<DequantizerQ40>(m);
|
||||
expected_Btype = GGML_TYPE_Q8_0;
|
||||
|
||||
Reference in New Issue
Block a user