Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-08 07:20:12 +00:00)
iqk_mul_mat: Arm implementation for iq2_xxs (llama.cpp version)

We get ~5% speedup for TG-128 (token generation) and 3X for PP-512 (prompt processing).
iqk_mul_mat.cpp: 113 changed lines
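For orientation before reading the diff: the new kernel decodes the iq2_xxs format, whose block layout is sketched below, abridged from ggml-common.h (which the patch now pulls in via GGML_COMMON_IMPL_C). The typedefs here are simplified stand-ins, not copied from this commit.

```cpp
#include <cstdint>

typedef uint16_t ggml_half;  // fp16 bit pattern, as stored by ggml
#define QK_K 256

// One super-block covers 256 weights.
typedef struct {
    ggml_half d;             // fp16 super-block scale; the kernel folds in an extra 1/8
    uint16_t  qs[QK_K / 8];  // 8 groups of 32 weights, each group read as two 32-bit words:
                             //   word 0: four 8-bit indices into iq2xxs_grid (8 weights each)
                             //   word 1: four 7-bit sign indices (bits 0..27)
                             //           plus a 4-bit group scale (bits 28..31)
} block_iq2_xxs;
```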
@@ -22,6 +22,9 @@
 #include "ggml-quants.h"
 #include "sgemm.h"

+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
 // clang-format off

 // This matrix - vector and matrix - matrix multiplication implementation
@@ -1944,8 +1947,113 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
     float d;
 };

+static const int8_t keven_signs_q2xs[1024] = {
+    1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+    1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+    1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+    1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+    1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+    1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+    1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+    1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+    1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+    1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+    1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+    1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+    1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+    1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+    1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+    1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+    1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+    1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+    1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+    1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+    1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+    1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+    1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+    1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+    1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+    1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+    1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+    1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+    1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+    1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+    1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+    1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
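The table above has a simple structure: entry m (0 <= m < 128) is an 8-byte sign pattern whose first seven signs are the bits of m and whose eighth sign is a parity byte keeping the count of -1's even. That is why 7 bits per group suffice for 8 signs, and why the kernel below can fetch a whole pattern with one 8-byte load (signs64 + m). A minimal generator sketch (hypothetical helper, not part of the commit):

```cpp
#include <cstdint>
#include <cstdio>

// Generates the 128 x 8 even-parity sign table used above:
// sign k (k = 0..6) of entry m is -1 iff bit k of m is set;
// sign 7 is chosen so every entry has an even number of -1's.
int main() {
    int8_t table[1024];
    for (int m = 0; m < 128; ++m) {
        int parity = 0;
        for (int k = 0; k < 7; ++k) {
            int bit = (m >> k) & 1;
            parity ^= bit;
            table[8*m + k] = bit ? -1 : 1;
        }
        table[8*m + 7] = parity ? -1 : 1;  // parity byte completes the even count
    }
    // Spot-check entry m = 3: should print -1 -1 1 1 1 1 1 1 (cf. first row above)
    for (int k = 0; k < 8; ++k) printf("%d ", table[8*3 + k]);
    printf("\n");
    return 0;
}
```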
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+    constexpr static int num_blocks() { return 8; }
+    constexpr static bool should_scale_quants() { return false; }
+
+    template <typename Q8>
+    inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+        auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
+        data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3
+        data.val[1] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
+        data.val[2] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3
+        data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
+        int32x4x2_t scales;
+        scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[2], 28), 1), vdupq_n_u32(1)));
+        scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(data.val[3], 28), 1), vdupq_n_u32(1)));
+        return scales;
+    }
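The shift/shift/or sequence on data.val[2] and data.val[3] computes, per 32-weight group, the block scale from the top four bits of the signs word; together with d = 0.125f * fp16 this matches the reference iq2_xxs dequantization. A scalar view of the same decode (sketch, name illustrative); the struct continues below:

```cpp
#include <cstdint>

// Scalar equivalent of the NEON sequence above: shift out the 4-bit
// scale, double it, and OR in 1, i.e. scale = 2*(aux >> 28) + 1 (always odd).
static inline int32_t iq2xxs_group_scale(uint32_t aux) {
    return (int32_t)(((aux >> 28) << 1) | 1);
}
```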
+
+    inline void prepare(int /*i*/, int j) {
+        const uint8_t * idx = (const uint8_t *)(data.val + j);
+        bits.b1.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 0])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 1])));
+        bits.b1.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 2])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 3])));
+        bits.b1.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 4])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 5])));
+        bits.b1.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 6])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 7])));
+        bits.b2.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 8])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[ 9])));
+        bits.b2.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[10])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[11])));
+        bits.b2.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[12])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[13])));
+        bits.b2.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2xxs_grid + idx[14])), vld1_s8((const int8_t *)(iq2xxs_grid + idx[15])));
+        const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+        const uint32_t * sidx = (const uint32_t *)(data.val + 2 + j);
+        bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 7) & 127)))));
+        bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[0] >> 21) & 127)))));
+        bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 7) & 127)))));
+        bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[1] >> 21) & 127)))));
+        bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 7) & 127)))));
+        bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[2] >> 21) & 127)))));
+        bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 0) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 7) & 127)))));
+        bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 14) & 127))), vld1_s8((const int8_t *)(signs64 + ((sidx[3] >> 21) & 127)))));
+        //auto mask = vdupq_n_u32(127);
+        //uint32x4_t sindex;
+        //sindex = vandq_u32(data.val[2+j], mask);
+        //mask = vshlq_n_u32(mask, 7);
+        //sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 1));
+        //mask = vshlq_n_u32(mask, 7);
+        //sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 2));
+        //mask = vshlq_n_u32(mask, 7);
+        //sindex = vorrq_u32(sindex, vshlq_n_u32(vandq_u32(data.val[2+j], mask), 3));
+        //const uint8_t * sidx = (const uint8_t *)&sindex;
+        //bits.b1.val[0] = vmulq_s8(bits.b1.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 0])), vld1_s8((const int8_t *)(signs64 + sidx[ 1]))));
+        //bits.b1.val[1] = vmulq_s8(bits.b1.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 2])), vld1_s8((const int8_t *)(signs64 + sidx[ 3]))));
+        //bits.b1.val[2] = vmulq_s8(bits.b1.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 4])), vld1_s8((const int8_t *)(signs64 + sidx[ 5]))));
+        //bits.b1.val[3] = vmulq_s8(bits.b1.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 6])), vld1_s8((const int8_t *)(signs64 + sidx[ 7]))));
+        //bits.b2.val[0] = vmulq_s8(bits.b2.val[0], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[ 8])), vld1_s8((const int8_t *)(signs64 + sidx[ 9]))));
+        //bits.b2.val[1] = vmulq_s8(bits.b2.val[1], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[10])), vld1_s8((const int8_t *)(signs64 + sidx[11]))));
+        //bits.b2.val[2] = vmulq_s8(bits.b2.val[2], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[12])), vld1_s8((const int8_t *)(signs64 + sidx[13]))));
+        //bits.b2.val[3] = vmulq_s8(bits.b2.val[3], vcombine_s8(vld1_s8((const int8_t *)(signs64 + sidx[14])), vld1_s8((const int8_t *)(signs64 + sidx[15]))));
+    }
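prepare() thus gathers, for each half super-block (j = 0 or 1), sixteen 8-byte rows from iq2xxs_grid and flips their signs by multiplying with the precomputed patterns, selecting pattern k of group g via (sidx[g] >> 7*k) & 127. The commented-out block appears to be a vectorized alternative for computing the sign indices, left disabled. A scalar sketch of the per-group sign application (names hypothetical):

```cpp
#include <cstdint>

// For one 32-weight group: apply four 7-bit sign patterns to four
// 8-byte grid rows. 'signs64' is keven_signs_q2xs viewed as uint64,
// exactly as in the kernel above; 'aux' is the group's signs word.
static void apply_signs_group(const uint64_t * signs64, uint32_t aux,
                              int8_t vals[32] /* grid bytes in, signed out */) {
    for (int k = 0; k < 4; ++k) {
        const int8_t * s = (const int8_t *)(signs64 + ((aux >> (7 * k)) & 127));
        for (int b = 0; b < 8; ++b) vals[8 * k + b] *= s[b];
    }
}
```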
+
+    uint32x4x4_t data;
+    struct Bits {
+        uint8x16x4_t b1;
+        uint8x16x4_t b2;
+    };
+    Bits bits;
+
+    float d;
+};
+
 template <int nrc_y, typename Dequantizer>
-void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
     const int nb = n / QK_K;

@@ -2466,6 +2574,9 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /
         case GGML_TYPE_IQ4_XS:
             MulMat::set_functions<DequantizerIQ4XS>(m);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            MulMat::set_functions<DequantizerIQ2XXS>(m);
+            break;
         case GGML_TYPE_Q4_0:
             MulMat::set_functions<DequantizerQ40>(m);
             row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);
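With this last hunk, iq2_xxs dispatches through the same templated kernel path as the other k-quants. For readers who want the scalar reference that the NEON code mirrors, here is a hedged sketch following ggml's dequantize_row_iq2_xxs; the helper name and signature are illustrative, not from this commit:

```cpp
#include <cstdint>
#include <cstring>

// Dequantizes one iq2_xxs super-block of 256 weights. 'grid' is
// iq2xxs_grid (256 x 8 bytes as uint64), 'signs64' is keven_signs_q2xs
// as uint64, 'd' is the super-block fp16 scale already converted to float.
static void dequant_iq2xxs_block(float d, const uint16_t qs[32],
                                 const uint64_t * grid, const uint64_t * signs64,
                                 float y[256]) {
    for (int g = 0; g < 8; ++g) {                 // 8 groups of 32 weights
        uint32_t aux0, aux1;
        memcpy(&aux0, qs + 4*g,     4);           // four 8-bit grid indices
        memcpy(&aux1, qs + 4*g + 2, 4);           // four 7-bit sign indices + 4-bit scale
        const float db = d * 0.125f * (2*(aux1 >> 28) + 1);  // cf. new_block() above
        for (int k = 0; k < 4; ++k) {
            const int8_t * q = (const int8_t *)(grid    + ((aux0 >> 8*k) & 255));
            const int8_t * s = (const int8_t *)(signs64 + ((aux1 >> 7*k) & 127));
            for (int b = 0; b < 8; ++b) y[32*g + 8*k + b] = db * q[b] * s[b];
        }
    }
}
```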