mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 07:20:12 +00:00
iqk_mul_mat: Arm implementation for iq2_s (llama.cpp version)
We get only a 2.07X speedup for PP-512, reaching 31 t/s, so iq2_s remains slow.
This commit is contained in:
149
iqk_mul_mat.cpp
149
iqk_mul_mat.cpp
@@ -2016,6 +2016,78 @@ const uint64_t keven_signs[128] = {
|
||||
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
// Storage for unpacked quants: two groups (b1, b2) of four 16-byte NEON vectors each.
// In this file it is the type of the `bits` member of DequantizerIQ2XXS and
// DequantizerIQ2S; DequantizerIQ2S::prepare() fills b1.val and b2.val via make4().
struct SimpleBits {
|
||||
uint8x16x4_t b1;
|
||||
uint8x16x4_t b2;
|
||||
};
|
||||
|
||||
// Sign lookup table with one 64-bit entry per 8-bit sign pattern (256 entries).
// In entry i, byte k (counting from the least significant byte) is 0xff when bit k
// of i is set and 0x01 otherwise; read as int8 those are -1 and +1, i.e. one
// multiplier per value in a group of eight.
// NOTE(review): no use of this table is visible in this part of the diff — confirm
// where it is referenced.
const uint64_t kall_signs[256] = {
|
||||
0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,
|
||||
0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,
|
||||
0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff,
|
||||
0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,
|
||||
0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,
|
||||
0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,
|
||||
0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,
|
||||
0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,
|
||||
0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,
|
||||
0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,
|
||||
0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,
|
||||
0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,
|
||||
0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,
|
||||
0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,
|
||||
0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,
|
||||
0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,
|
||||
0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,
|
||||
0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,
|
||||
0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,
|
||||
0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,
|
||||
0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,
|
||||
0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,
|
||||
0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,
|
||||
0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,
|
||||
0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff,
|
||||
0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,
|
||||
0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,
|
||||
0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,
|
||||
0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,
|
||||
0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,
|
||||
0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,
|
||||
0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,
|
||||
0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,
|
||||
0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,
|
||||
0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,
|
||||
0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,
|
||||
0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,
|
||||
0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,
|
||||
0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,
|
||||
0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,
|
||||
0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,
|
||||
0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,
|
||||
0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff,
|
||||
0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,
|
||||
0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff,
|
||||
0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,
|
||||
0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,
|
||||
0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,
|
||||
0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff,
|
||||
0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,
|
||||
0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,
|
||||
0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,
|
||||
0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,
|
||||
0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,
|
||||
0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,
|
||||
0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,
|
||||
0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,
|
||||
0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,
|
||||
0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,
|
||||
0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,
|
||||
0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,
|
||||
0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,
|
||||
0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
|
||||
DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
@@ -2064,15 +2136,79 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
|
||||
}
|
||||
|
||||
uint32x4x4_t data;
|
||||
struct Bits {
|
||||
uint8x16x4_t b1;
|
||||
uint8x16x4_t b2;
|
||||
};
|
||||
Bits bits;
|
||||
SimpleBits bits;
|
||||
|
||||
float d;
|
||||
};
|
||||
|
||||
|
||||
|
||||
// Arm NEON dequantizer for iq2_s super-blocks (block_iq2_s). It is the type that
// MulMat::set_mul_mat() passes to set_functions<>() for GGML_TYPE_IQ2_S.
//
//  - new_block(i, ...) reads the scale data of super-block i;
//  - prepare(i, j) unpacks half j of super-block i into `bits`, signs applied.
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
    DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return false; }

    // Start processing super-block i: stores 0.125 * (fp16 super-block scale) in `d`
    // and returns the 16 sub-block scales, each computed as 2*s + 1 from a 4-bit
    // field s, widened by make_wider() (defined elsewhere in this file).
    // The q8 and acc arguments are unused.
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);

        // Each of the 8 scale bytes carries two 4-bit scales. Put the bytes in the low
        // half and the bytes shifted right by 4 in the high half, mask to 4 bits, then
        // interleave via scale_shuffle so the order is low0, high0, low1, high1, ...
        const auto packed  = vld1_u8(x[i].scales);
        const auto shifted = vshr_n_u8(packed, 4);
        const auto nibbles = vqtbl1q_u8(vandq_u8(vcombine_u8(packed, shifted), vdupq_n_u8(0xf)),
                                        vreinterpretq_u8_u64(scale_shuffle));
        // Keep the signed result in its own int8x16_t variable: assigning it back to the
        // unsigned vector only compiles when lax vector conversions are enabled.
        const int8x16_t scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(nibbles, 1), vdupq_n_u8(1)));
        int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
        return make_wider(scales16);
    }

    // Unpack 8 grid entries (64 quants) into b[0..3] and apply their signs.
    //   signs16 : 16 sign bytes; each bit is the sign of one quant
    //   shuffle : table indices selecting the sign byte for each half of a vector;
    //             advanced by `step` after every output vector (in/out)
    //   qs      : 8 low bytes of the grid indices
    //   qh      : 2 bytes, each holding the 2 high index bits of 4 consecutive entries
    //   smask   : per-lane bit mask (0x01, 0x02, ..., 0x80 repeated)
    //   m1      : vector of ones
    //   b       : output, 4 vectors
    static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh,
            const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, uint8x16_t * b) {
        for (int k = 0; k < 2; ++k) {
            // Bits (2l, 2l+1) of qh[k] are bits 8..9 of the l-th grid index. Extract them
            // with plain shifts rather than viewing a uint32_t pair through a uint16_t
            // pointer (no aliasing violation, no dependence on byte order).
            const uint32_t high = qh[k];
            const uint32_t idx0 = qs[4*k+0] | ((high << 8) & 0x300);
            const uint32_t idx1 = qs[4*k+1] | ((high << 6) & 0x300);
            const uint32_t idx2 = qs[4*k+2] | ((high << 4) & 0x300);
            const uint32_t idx3 = qs[4*k+3] | ((high << 2) & 0x300);

            b[2*k+0] = vcombine_u8(vld1_u8(reinterpret_cast<const uint8_t *>(iq2s_grid + idx0)),
                                   vld1_u8(reinterpret_cast<const uint8_t *>(iq2s_grid + idx1)));
            // Lanes whose sign bit is set become 0xff (-1), all others 0x01 (+1).
            const auto sel1 = vqtbl1q_u8(signs16, shuffle);
            const auto s1 = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(sel1, smask), smask), m1));
            b[2*k+0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2*k+0]), s1));
            shuffle = vaddq_u8(shuffle, step);

            b[2*k+1] = vcombine_u8(vld1_u8(reinterpret_cast<const uint8_t *>(iq2s_grid + idx2)),
                                   vld1_u8(reinterpret_cast<const uint8_t *>(iq2s_grid + idx3)));
            const auto sel2 = vqtbl1q_u8(signs16, shuffle);
            const auto s2 = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(sel2, smask), smask), m1));
            b[2*k+1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2*k+1]), s2));
            shuffle = vaddq_u8(shuffle, step);
        }
    }

    // Unpack half j of super-block i into bits.b1 / bits.b2 (8 vectors of 16 signed quants).
    inline void prepare(int i, int j) {

        const auto smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
        const auto m1    = vdupq_n_u8(1);
        const auto step  = vdupq_n_u8(2);

        const auto * qs = x[i].qs + 16*j;
        const auto * qh = x[i].qh + 4*j;
        // The sign bytes are read QK_K/8 bytes past the grid-index bytes in x[i].qs.
        const auto signs16 = vld1q_u8(qs + QK_K/8);

        // Lanes 0..7 select sign byte 0, lanes 8..15 select sign byte 1; make4 advances by 2.
        auto shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
        make4(signs16, shuffle, qs+0, qh+0, smask, step, m1, bits.b1.val);
        make4(signs16, shuffle, qs+8, qh+2, smask, step, m1, bits.b2.val);
    }

    SimpleBits bits;

    // Table indices that interleave the low-nibble and high-nibble scale bytes in new_block().
    constexpr static const uint64x2_t scale_shuffle = { 0x0b030a0209010800, 0x0f070e060d050c04 };

    float d;

};
|
||||
|
||||
|
||||
template <int nrc_y, typename Dequantizer>
|
||||
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -2620,6 +2756,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
MulMat::set_functions<DequantizerIQ2XXS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
MulMat::set_functions<DequantizerIQ2S>(m);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
MulMat::set_functions<DequantizerQ40>(m);
|
||||
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
|
||||
|
||||
Reference in New Issue
Block a user