mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Trellis quants: faster CPU prompt processing (#482)
* Experimenting with dequant + f32 GEMM For iq4_kt this results in a massive PP improvement from PP512 = ~42 t/s to PP512 = 128 t/s. * Experimenting with dequant + f32 GEMM iq2_kt: from PP512 = 57.3 t/s to PP512 = 135.0 t/s iq3_kt: from PP512 = 43.8 t/s to PP512 = 131.4 t/s * Experimenting with dequant + f16 GEMM on NEON iq2_kt: PP512 = 79 t/s from 42 t/s iq3_kt: PP512 = 81 t/s from 35 t/s Also, found the reason why the f16 implementation for iq4_kt was not working: it overflows. It works after mltiplying with the row scale before doing the multiply-adds. * Experimenting with dequant + f16 GEMM on NEON iq4_kt: PP512 = 86 t/s from 29 t/s * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -1618,8 +1618,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
|
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
|
||||||
.vec_dot = vec_dot_iq4_kt_q8_k,
|
.vec_dot = vec_dot_iq4_kt_q8_k,
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
//.vec_dot_type = GGML_TYPE_F16,
|
.vec_dot_type = GGML_TYPE_F16,
|
||||||
.vec_dot_type = GGML_TYPE_F32,
|
|
||||||
#else
|
#else
|
||||||
.vec_dot_type = GGML_TYPE_F32,
|
.vec_dot_type = GGML_TYPE_F32,
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static inline uint32_t trellis_next(uint32_t& val) {
|
inline uint32_t trellis_next(uint32_t& val) {
|
||||||
constexpr uint32_t ka = 89226354;
|
constexpr uint32_t ka = 89226354;
|
||||||
constexpr uint32_t kb = 64248484;
|
constexpr uint32_t kb = 64248484;
|
||||||
constexpr uint32_t kmask = 0x8fff8fff;
|
constexpr uint32_t kmask = 0x8fff8fff;
|
||||||
@@ -22,7 +22,7 @@ static inline uint32_t trellis_next(uint32_t& val) {
|
|||||||
return (val & kmask) ^ km32;
|
return (val & kmask) ^ km32;
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline float trellis_gen(uint32_t& val, uint32_t* s) {
|
inline float trellis_gen(uint32_t& val, uint32_t* s) {
|
||||||
const ggml_fp16_t * h = (const ggml_fp16_t *)s;
|
const ggml_fp16_t * h = (const ggml_fp16_t *)s;
|
||||||
s[0] = trellis_next(val);
|
s[0] = trellis_next(val);
|
||||||
return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
|
return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
|
||||||
@@ -59,7 +59,7 @@ struct Trellis1 {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static inline __m256 trellis_gen8(__m256i i8) {
|
inline __m256 trellis_gen8(__m256i i8) {
|
||||||
// split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
|
// split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
|
||||||
__m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
|
__m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
|
||||||
__m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
|
__m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
|
||||||
@@ -97,8 +97,47 @@ struct Trellis2 {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||||
|
GGML_ASSERT(n%QK_K == 0);
|
||||||
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
|
Trellis1 trellis;
|
||||||
|
|
||||||
|
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||||
|
auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||||
|
|
||||||
|
union { __m256 vec; float val[8]; } s_helper;
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||||
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
|
auto d = _mm256_set1_ps(*dptr * 31.75f * 1.05f);
|
||||||
|
const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||||
|
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||||
|
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||||
|
s8 = _mm_shuffle_epi8(values, s8);
|
||||||
|
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||||
|
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
|
||||||
|
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||||
|
auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]);
|
||||||
|
auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]);
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096)));
|
||||||
|
auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096)));
|
||||||
|
_mm256_storeu_ps(y + i*QK_K + 64*ib + 8*j + 0, xval1);
|
||||||
|
_mm256_storeu_ps(y + i*QK_K + 64*ib + 8*j + 32, xval2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
y += stride_y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
@@ -159,14 +198,63 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline __m256 abs_ps(__m256 vals) {
|
inline __m256 abs_ps(__m256 vals) {
|
||||||
// Clear sign-bit of all the 32-bit floats in vals
|
// Clear sign-bit of all the 32-bit floats in vals
|
||||||
__m256 sign_bit = _mm256_set1_ps(-0.0f);
|
__m256 sign_bit = _mm256_set1_ps(-0.0f);
|
||||||
return _mm256_andnot_ps(sign_bit, vals);
|
return _mm256_andnot_ps(sign_bit, vals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||||
|
GGML_ASSERT(n%QK_K == 0);
|
||||||
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
|
Trellis1 trellis;
|
||||||
|
|
||||||
|
union { __m256 vec; float val[8]; } s_helper;
|
||||||
|
|
||||||
|
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||||
|
|
||||||
|
__m256i all_signs[4];
|
||||||
|
auto mask1 = _mm256_set1_epi32(0x01);
|
||||||
|
auto mask2 = _mm256_set1_epi32(0x10);
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||||
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
|
auto d = _mm256_set1_ps(*dptr * 31.75f * 1.015f);
|
||||||
|
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||||
|
const uint8_t * qh = x[i].qh;
|
||||||
|
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||||
|
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||||
|
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||||
|
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
|
||||||
|
for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j)));
|
||||||
|
for (int ib = 0; ib < 4; ++ib) {
|
||||||
|
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
|
||||||
|
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
uint32_t val1 = ql[4*ib+j ] + 4096;
|
||||||
|
uint32_t val2 = ql[4*ib+j+16] + 4096;
|
||||||
|
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000));
|
||||||
|
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000));
|
||||||
|
all_signs[j] = _mm256_srli_epi32(all_signs[j], 1);
|
||||||
|
auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
|
||||||
|
auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
|
||||||
|
x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1)));
|
||||||
|
x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2)));
|
||||||
|
_mm256_storeu_ps(y + i*QK_K+32*ib+8*j , x_val1);
|
||||||
|
_mm256_storeu_ps(y + i*QK_K+32*ib+8*j+128, x_val2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
y += stride_y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
@@ -227,8 +315,57 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||||
|
GGML_ASSERT(n%QK_K == 0);
|
||||||
|
const int nb = n/QK_K;
|
||||||
|
constexpr int kNumGroups = 64;
|
||||||
|
|
||||||
|
Trellis2 trellis;
|
||||||
|
|
||||||
|
union { __m256 vec; float val[8]; } s_helper;
|
||||||
|
union { __m256i vec; uint32_t val[8]; } o_helper;
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||||
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
|
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
|
||||||
|
auto dav = _mm256_set1_ps(dptr[1]);
|
||||||
|
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
||||||
|
const uint32_t * shb = x[i].qs;
|
||||||
|
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||||
|
const uint8_t * qh = ql + kNumGroups;
|
||||||
|
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
|
||||||
|
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
|
||||||
|
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
|
||||||
|
for (int ib = 0; ib < 4; ++ib) {
|
||||||
|
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
|
||||||
|
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||||
|
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||||
|
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||||
|
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||||
|
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||||
|
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||||
|
auto x_val1 = _mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav);
|
||||||
|
auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav);
|
||||||
|
|
||||||
|
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1);
|
||||||
|
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
y += stride_y;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
constexpr int kNumGroups = 64;
|
constexpr int kNumGroups = 64;
|
||||||
@@ -333,6 +470,16 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||||
|
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||||
|
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break;
|
||||||
|
default: return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
#else // !__x86_64__
|
#else // !__x86_64__
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@@ -403,8 +550,52 @@ struct Trellis1 {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float16_t * y, size_t stride_y, int nrc_x) {
|
||||||
|
GGML_ASSERT(n%QK_K == 0);
|
||||||
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
|
Trellis1 trellis;
|
||||||
|
|
||||||
|
auto values = vld1q_s8(iq4k_values);
|
||||||
|
|
||||||
|
union { float16x8_t vec; float16_t val[8]; } s_helper;
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||||
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
|
const float d = *dptr * 31.75f * 1.05f;
|
||||||
|
auto vd = vdupq_n_f32(d);
|
||||||
|
const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||||
|
auto u32 = *(const uint32_t *)x[i].scales;
|
||||||
|
auto s8_u32 = uint32x2_t{u32, u32 >> 4};
|
||||||
|
s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
|
||||||
|
auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32));
|
||||||
|
auto s16 = vmovl_s8(s8);
|
||||||
|
auto s32l = vmovl_s16(vget_low_s16 (s16));
|
||||||
|
auto s32h = vmovl_s16(vget_high_s16(s16));
|
||||||
|
auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l));
|
||||||
|
auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h));
|
||||||
|
s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h));
|
||||||
|
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||||
|
auto scale1 = vdupq_n_f16(s_helper.val[2*ib+0]);
|
||||||
|
auto scale2 = vdupq_n_f16(s_helper.val[2*ib+1]);
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
auto xval1 = vmulq_f16(scale1, trellis.gen8(ql[8*ib+j+0]+4096));
|
||||||
|
auto xval2 = vmulq_f16(scale2, trellis.gen8(ql[8*ib+j+4]+4096));
|
||||||
|
vst1q_f16(y + i*QK_K + 64*ib + 8*j + 0, xval1);
|
||||||
|
vst1q_f16(y + i*QK_K + 64*ib + 8*j + 32, xval2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
y += stride_y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
@@ -466,8 +657,61 @@ static void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, float16_t * y, size_t stride_y, int nrc_x) {
|
||||||
|
GGML_ASSERT(n%QK_K == 0);
|
||||||
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
|
Trellis1 trellis;
|
||||||
|
|
||||||
|
union { float16x8_t vec; float16_t val[8]; } s_helper;
|
||||||
|
|
||||||
|
uint16x8_t all_signs[4];
|
||||||
|
auto mask1 = vdupq_n_u16(0x01);
|
||||||
|
auto mask2 = vdupq_n_u16(0x10);
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||||
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
|
const float d = *dptr * 31.75f * 1.015f;
|
||||||
|
auto vd = vdupq_n_f32(d);
|
||||||
|
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||||
|
const uint8_t * qh = x[i].qh;
|
||||||
|
auto u32 = *(const uint32_t *)x[i].scales;
|
||||||
|
auto s8_u32 = uint32x2_t{u32, u32 >> 4};
|
||||||
|
s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
|
||||||
|
auto s16 = vmovl_s8(vreinterpret_s8_u32(s8_u32));
|
||||||
|
auto s32l = vmovl_s16(vget_low_s16 (s16));
|
||||||
|
auto s32h = vmovl_s16(vget_high_s16(s16));
|
||||||
|
auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l));
|
||||||
|
auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h));
|
||||||
|
s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h));
|
||||||
|
for (int j = 0; j < 4; ++j) all_signs[j] = vmovl_u8(vld1_u8(qh + 8*j));
|
||||||
|
for (int ib = 0; ib < 4; ++ib) {
|
||||||
|
auto scale1 = vdupq_n_f16(s_helper.val[ib+0]);
|
||||||
|
auto scale2 = vdupq_n_f16(s_helper.val[ib+4]);
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
uint32_t val1 = ql[4*ib+j ] + 4096;
|
||||||
|
uint32_t val2 = ql[4*ib+j+16] + 4096;
|
||||||
|
auto sign1 = vshlq_n_u16(vandq_u16(all_signs[j], mask1), 15);
|
||||||
|
auto sign2 = vshlq_n_u16(vandq_u16(all_signs[j], mask2), 11);
|
||||||
|
all_signs[j] = vshrq_n_u16(all_signs[j], 1);
|
||||||
|
auto x_val1 = vabsq_f16(trellis.gen8(val1));
|
||||||
|
auto x_val2 = vabsq_f16(trellis.gen8(val2));
|
||||||
|
x_val1 = vmulq_f16(scale1, vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(x_val1), sign1)));
|
||||||
|
x_val2 = vmulq_f16(scale2, vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(x_val2), sign2)));
|
||||||
|
vst1q_f16(y + i*QK_K+32*ib+8*j , x_val1);
|
||||||
|
vst1q_f16(y + i*QK_K+32*ib+8*j+128, x_val2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
y += stride_y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
|
|
||||||
@@ -527,8 +771,63 @@ static void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float16_t * y, size_t stride_y, int nrc_x) {
|
||||||
|
GGML_ASSERT(n%QK_K == 0);
|
||||||
|
const int nb = n/QK_K;
|
||||||
|
constexpr int kNumGroups = 64;
|
||||||
|
|
||||||
|
Trellis1 trellis;
|
||||||
|
|
||||||
|
union { float16x8_t vec; float16_t val[8]; } s_helper;
|
||||||
|
union { uint16x8_t vec; uint16_t val[8]; } o_helper;
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||||
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
|
auto d = dptr[0] * 31.75f * 1.01f;
|
||||||
|
//auto dav = dptr[1];
|
||||||
|
// Something goes wrong when we add the average. Why?
|
||||||
|
//auto vav = std::abs(dav) > 0.00006103515625f ? vdupq_n_f16(GGML_FP32_TO_FP16(dav)) : vdupq_n_f16(0);
|
||||||
|
auto vd = vdupq_n_f32(d);
|
||||||
|
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const uint32_t * shb = x[i].qs;
|
||||||
|
auto vshb = vld1q_u32_x2(shb);
|
||||||
|
auto vshb16 = vcombine_u16(vmovn_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff))), vmovn_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff))));
|
||||||
|
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||||
|
const uint8_t * qh = ql + kNumGroups;
|
||||||
|
auto iscales = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vshb16, 1)), vdupq_n_s16(64));
|
||||||
|
auto s32l = vmovl_s16(vget_low_s16(iscales));
|
||||||
|
auto s32h = vmovl_s16(vget_high_s16(iscales));
|
||||||
|
auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l));
|
||||||
|
auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h));
|
||||||
|
s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h));
|
||||||
|
o_helper.vec = vaddq_u16(vshlq_n_u16(vandq_u16(vshb16, vdupq_n_u16(1)), 15), vdupq_n_u16(4096));
|
||||||
|
for (int ib = 0; ib < 4; ++ib) {
|
||||||
|
auto scale1 = vdupq_n_f16(s_helper.val[ib+0]);
|
||||||
|
auto scale2 = vdupq_n_f16(s_helper.val[ib+4]);
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||||
|
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||||
|
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||||
|
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||||
|
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||||
|
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||||
|
//auto x_val1 = vfmaq_f16(vav, scale1, trellis.gen8(val1, val3));
|
||||||
|
//auto x_val2 = vfmaq_f16(vav, scale2, trellis.gen8(val2, val4));
|
||||||
|
auto x_val1 = vmulq_f16(scale1, trellis.gen8(val1, val3));
|
||||||
|
auto x_val2 = vmulq_f16(scale2, trellis.gen8(val2, val4));
|
||||||
|
vst1q_f16(y + i*QK_K+32*ib+8*j+ 0, x_val1);
|
||||||
|
vst1q_f16(y + i*QK_K+32*ib+8*j+128, x_val2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
y += stride_y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
constexpr int kNumGroups = 64;
|
constexpr int kNumGroups = 64;
|
||||||
@@ -548,8 +847,6 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
auto sum = vdupq_n_f16(0);
|
auto sum = vdupq_n_f16(0);
|
||||||
for (int i = 0; i < n/8; ++i) sum = vaddq_f16(sum, vld1q_f16(y[iy] + 8*i));
|
for (int i = 0; i < n/8; ++i) sum = vaddq_f16(sum, vld1q_f16(y[iy] + 8*i));
|
||||||
auto sum32 = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum)), vcvt_f32_f16(vget_high_f16(sum)));
|
auto sum32 = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum)), vcvt_f32_f16(vget_high_f16(sum)));
|
||||||
//auto sum32 = vdupq_n_f32(0);
|
|
||||||
//for (int i = 0; i < n/4; ++i) sum32 = vaddq_f32(sum32, vcvt_f32_f16(vld1_f16(y[iy] + 4*i)));
|
|
||||||
row_sum[iy] = vaddvq_f32(sum32);
|
row_sum[iy] = vaddvq_f32(sum32);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -557,6 +854,7 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||||
auto d = dptr[0] * 31.75f * 1.01f;
|
auto d = dptr[0] * 31.75f * 1.01f;
|
||||||
auto dav = dptr[1];
|
auto dav = dptr[1];
|
||||||
|
auto vd = vdupq_n_f32(d);
|
||||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||||
|
|
||||||
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f16(0);
|
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f16(0);
|
||||||
@@ -568,7 +866,12 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||||
const uint8_t * qh = ql + kNumGroups;
|
const uint8_t * qh = ql + kNumGroups;
|
||||||
auto iscales = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vshb16, 1)), vdupq_n_s16(64));
|
auto iscales = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vshb16, 1)), vdupq_n_s16(64));
|
||||||
s_helper.vec = vcvtq_f16_s16(iscales);
|
auto s32l = vmovl_s16(vget_low_s16(iscales));
|
||||||
|
auto s32h = vmovl_s16(vget_high_s16(iscales));
|
||||||
|
auto f32l = vmulq_f32(vd, vcvtq_f32_s32(s32l));
|
||||||
|
auto f32h = vmulq_f32(vd, vcvtq_f32_s32(s32h));
|
||||||
|
s_helper.vec = vcombine_f16(vcvt_f16_f32(f32l), vcvt_f16_f32(f32h));
|
||||||
|
//s_helper.vec = vcvtq_f16_s16(iscales);
|
||||||
o_helper.vec = vaddq_u16(vshlq_n_u16(vandq_u16(vshb16, vdupq_n_u16(1)), 15), vdupq_n_u16(4096));
|
o_helper.vec = vaddq_u16(vshlq_n_u16(vandq_u16(vshb16, vdupq_n_u16(1)), 15), vdupq_n_u16(4096));
|
||||||
for (int ib = 0; ib < 4; ++ib) {
|
for (int ib = 0; ib < 4; ++ib) {
|
||||||
auto scale1 = vdupq_n_f16(s_helper.val[ib+0]);
|
auto scale1 = vdupq_n_f16(s_helper.val[ib+0]);
|
||||||
@@ -602,18 +905,18 @@ static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
if constexpr (nrc_y == 1) {
|
if constexpr (nrc_y == 1) {
|
||||||
auto sum16 = vaddq_f16(accd[0], accd[1]);
|
auto sum16 = vaddq_f16(accd[0], accd[1]);
|
||||||
auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum16)), vcvt_f32_f16(vget_high_f16(sum16)));
|
auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum16)), vcvt_f32_f16(vget_high_f16(sum16)));
|
||||||
info.store(ix, 0, d*vaddvq_f32(sum) + dav*row_sum[0]);
|
info.store(ix, 0, vaddvq_f32(sum) + dav*row_sum[0]);
|
||||||
} else {
|
} else {
|
||||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||||
auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(accd[iy])), vcvt_f32_f16(vget_high_f16(accd[iy])));
|
auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(accd[iy])), vcvt_f32_f16(vget_high_f16(accd[iy])));
|
||||||
info.store(ix, iy, d*vaddvq_f32(sum) + dav*row_sum[iy]);
|
info.store(ix, iy, vaddvq_f32(sum) + dav*row_sum[iy]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int nrc_y>
|
template <int nrc_y>
|
||||||
static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
assert(n%QK_K == 0);
|
assert(n%QK_K == 0);
|
||||||
const int nb = n/QK_K;
|
const int nb = n/QK_K;
|
||||||
constexpr int kNumGroups = 64;
|
constexpr int kNumGroups = 64;
|
||||||
@@ -693,11 +996,11 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
|||||||
|
|
||||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||||
|
|
||||||
if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
|
//if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
|
||||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
|
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
|
||||||
func16 = nullptr;
|
// func16 = nullptr;
|
||||||
return true;
|
// return true;
|
||||||
}
|
//}
|
||||||
|
|
||||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) {
|
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) {
|
||||||
return false;
|
return false;
|
||||||
@@ -722,6 +1025,17 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||||
|
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||||
|
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||||
|
default: return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -8,4 +8,6 @@
|
|||||||
|
|
||||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||||
|
|
||||||
|
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * vy, size_t stride_y, int nrc_x);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -233,6 +233,24 @@ struct MulMat {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
|
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
|
||||||
|
static inline ggml_type is_dequant_better(ggml_type type, int nrc_y) {
|
||||||
|
#ifdef __AVX2__
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||||
|
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||||
|
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||||
|
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||||
|
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return type;
|
||||||
|
}
|
||||||
static inline int num_rows(ggml_type type) {
|
static inline int num_rows(ggml_type type) {
|
||||||
#ifdef HAVE_FANCY_SIMD
|
#ifdef HAVE_FANCY_SIMD
|
||||||
switch (type) {
|
switch (type) {
|
||||||
@@ -312,6 +330,49 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
|
|||||||
float * C, long stride_C, int ith, int nth) {
|
float * C, long stride_C, int ith, int nth) {
|
||||||
|
|
||||||
MulMat mm;
|
MulMat mm;
|
||||||
|
|
||||||
|
auto etypeA = ggml_type(typeA);
|
||||||
|
if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
|
||||||
|
if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int k_x_step = 32;
|
||||||
|
|
||||||
|
auto num_rows = MulMat::num_rows(ggml_type(dequant_type));
|
||||||
|
GGML_ASSERT(Nx%num_rows == 0);
|
||||||
|
auto nrc_x = (Nx/num_rows + nth - 1)/nth;
|
||||||
|
auto first_x = ith*nrc_x;
|
||||||
|
if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
|
||||||
|
first_x *= num_rows;
|
||||||
|
nrc_x *= num_rows;
|
||||||
|
|
||||||
|
auto type_size = ggml_type_size(dequant_type);
|
||||||
|
|
||||||
|
thread_local std::vector<char> f;
|
||||||
|
|
||||||
|
size_t row_size_qx = ne00*type_size;
|
||||||
|
size_t row_size_qy = strideB;
|
||||||
|
|
||||||
|
//printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
|
||||||
|
|
||||||
|
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
|
||||||
|
|
||||||
|
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
||||||
|
auto this_info = info;
|
||||||
|
this_info.s += ix;
|
||||||
|
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||||
|
if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
|
||||||
|
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
|
||||||
|
GGML_ABORT("Fatal error");
|
||||||
|
}
|
||||||
|
mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
|
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
#include "iqk_quantize.h"
|
#include "iqk_quantize.h"
|
||||||
#include "iqk_config.h"
|
#include "iqk_config.h"
|
||||||
|
|
||||||
|
#include "iqk_gemm_ktquants.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
@@ -8241,6 +8243,9 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
|
|||||||
|
|
||||||
void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) {
|
void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) {
|
||||||
assert(k % QuantizerIQ2KT::kSuperBlockSize == 0);
|
assert(k % QuantizerIQ2KT::kSuperBlockSize == 0);
|
||||||
|
#ifdef __AVX2__
|
||||||
|
if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return;
|
||||||
|
#endif
|
||||||
const int nb = k / QuantizerIQ2KT::kSuperBlockSize;
|
const int nb = k / QuantizerIQ2KT::kSuperBlockSize;
|
||||||
const float * dptr = (const float *)x;
|
const float * dptr = (const float *)x;
|
||||||
const float d = *dptr * QuantizerIQ2KT::kScale;
|
const float d = *dptr * QuantizerIQ2KT::kScale;
|
||||||
@@ -8494,6 +8499,9 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
|
|||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
|
void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
|
||||||
|
#ifdef __AVX2__
|
||||||
|
if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return;
|
||||||
|
#endif
|
||||||
using Q = QuantizerIQ3KT;
|
using Q = QuantizerIQ3KT;
|
||||||
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
|
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
|
||||||
assert(k % Q::kSuperBlockSize == 0);
|
assert(k % Q::kSuperBlockSize == 0);
|
||||||
@@ -8750,6 +8758,9 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
|
|||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
|
void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
|
||||||
|
#ifdef __AVX2__
|
||||||
|
if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return;
|
||||||
|
#endif
|
||||||
using Q = QuantizerIQ4KT;
|
using Q = QuantizerIQ4KT;
|
||||||
assert(k % Q::kSuperBlockSize == 0);
|
assert(k % Q::kSuperBlockSize == 0);
|
||||||
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
|
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
|
||||||
|
|||||||
Reference in New Issue
Block a user