iqk_mul_mat: Arm implementation for iq2_xxs (llama.cpp version)

We get ~5% speeedup for TG-128, 3X for PP-512
This commit is contained in:
Iwan Kawrakow
2024-05-27 13:38:26 +02:00
parent b51922530f
commit d7ab97149f

View File

@@ -22,6 +22,9 @@
#include "ggml-quants.h"
#include "sgemm.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
// clang-format off
// This matrix - vector and matrix - matrix multiplication implementation
@@ -1944,8 +1947,113 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
float d;
};
static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
};
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3
data.val[1] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
data.val[2] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3
data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
int32x4x2_t scales;
scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[2], 28), 1), vdupq_n_u32(1)));
scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[3], 28), 1), vdupq_n_u32(1)));
return scales;
}
inline void prepare(int /*i*/, int j) {
const uint8_t * idx = (const uint8_t *)(data.val + j);
bits.b1.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 0])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 1])));
bits.b1.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 2])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 3])));
bits.b1.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 4])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 5])));
bits.b1.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 6])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 7])));
bits.b2.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 8])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 9])));
bits.b2.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[10])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[11])));
bits.b2.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[12])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[13])));
bits.b2.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[14])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[15])));
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
const uint32_t * sidx = (const uint32_t *)(data.val + 2 + j);
bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 7) & 127)))));
bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 21) & 127)))));
bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 7) & 127)))));
bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 21) & 127)))));
bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 7) & 127)))));
bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 21) & 127)))));
bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 7) & 127)))));
bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 21) & 127)))));
//auto mask = vdupq_n_u32(127);
//uint32x4_t sindex;
//sindex = vandq_u32(data.val[2+j], mask);
//mask = vshlq_n_u32(mask, 7);
//sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 1));
//mask = vshlq_n_u32(mask, 7);
//sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 2));
//mask = vshlq_n_u32(mask, 7);
//sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 3));
//const uint8_t * sidx = (const uint8_t *)&sindex;
//bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 0])), vld1_s8((const int8_t *)(signs64 + sidx[ 1]))));
//bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 2])), vld1_s8((const int8_t *)(signs64 + sidx[ 3]))));
//bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 4])), vld1_s8((const int8_t *)(signs64 + sidx[ 5]))));
//bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 6])), vld1_s8((const int8_t *)(signs64 + sidx[ 7]))));
//bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 8])), vld1_s8((const int8_t *)(signs64 + sidx[ 9]))));
//bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[10])), vld1_s8((const int8_t *)(signs64 + sidx[11]))));
//bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[12])), vld1_s8((const int8_t *)(signs64 + sidx[13]))));
//bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[14])), vld1_s8((const int8_t *)(signs64 + sidx[15]))));
}
uint32x4x4_t data;
struct Bits {
uint8x16x4_t b1;
uint8x16x4_t b2;
};
Bits bits;
float d;
};
template <int nrc_y, typename Dequantizer>
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
@@ -2466,6 +2574,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
case GGML_TYPE_IQ4_XS:
MulMat::set_functions<DequantizerIQ4XS>(m);
break;
case GGML_TYPE_IQ2_XXS:
MulMat::set_functions<DequantizerIQ2XXS>(m);
break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);