mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-12 06:50:08 +00:00
Experimenting with dequant + f32 GEMM
iq2_kt: from PP512 = 57.3 t/s to PP512 = 135.0 t/s iq3_kt: from PP512 = 43.8 t/s to PP512 = 131.4 t/s
This commit is contained in:
@@ -97,6 +97,45 @@ struct Trellis2 {
|
||||
}
|
||||
};
|
||||
|
||||
void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||
auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
auto d = _mm256_set1_ps(*dptr * 31.75f * 1.05f);
|
||||
const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||
s8 = _mm_shuffle_epi8(values, s8);
|
||||
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
|
||||
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]);
|
||||
auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096)));
|
||||
auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096)));
|
||||
_mm256_storeu_ps(y + i*QK_K + 64*ib + 8*j + 0, xval1);
|
||||
_mm256_storeu_ps(y + i*QK_K + 64*ib + 8*j + 32, xval2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
y += stride_y;
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -165,6 +204,55 @@ static inline __m256 abs_ps(__m256 vals) {
|
||||
return _mm256_andnot_ps(sign_bit, vals);
|
||||
}
|
||||
|
||||
void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
|
||||
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||
|
||||
__m256i all_signs[4];
|
||||
auto mask1 = _mm256_set1_epi32(0x01);
|
||||
auto mask2 = _mm256_set1_epi32(0x10);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
auto d = _mm256_set1_ps(*dptr * 31.75f * 1.015f);
|
||||
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
|
||||
for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j)));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
|
||||
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint32_t val1 = ql[4*ib+j ] + 4096;
|
||||
uint32_t val2 = ql[4*ib+j+16] + 4096;
|
||||
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000));
|
||||
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000));
|
||||
all_signs[j] = _mm256_srli_epi32(all_signs[j], 1);
|
||||
auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
|
||||
auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
|
||||
x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1)));
|
||||
x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2)));
|
||||
_mm256_storeu_ps(y + i*QK_K+32*ib+8*j , x_val1);
|
||||
_mm256_storeu_ps(y + i*QK_K+32*ib+8*j+128, x_val2);
|
||||
}
|
||||
}
|
||||
}
|
||||
y += stride_y;
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -227,6 +315,55 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
}
|
||||
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
Trellis2 trellis;
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
union { __m256i vec; uint32_t val[8]; } o_helper;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
|
||||
auto dav = _mm256_set1_ps(dptr[1]);
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
|
||||
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
|
||||
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
|
||||
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
auto x_val1 = _mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav);
|
||||
auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav);
|
||||
|
||||
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1);
|
||||
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
y += stride_y;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -333,53 +470,14 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
|
||||
}
|
||||
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
GGML_ASSERT(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
Trellis2 trellis;
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
union { __m256i vec; uint32_t val[8]; } o_helper;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
|
||||
auto dav = _mm256_set1_ps(dptr[1]);
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
|
||||
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
|
||||
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
|
||||
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
auto x_val1 = _mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav);
|
||||
auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav);
|
||||
|
||||
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1);
|
||||
_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
y += stride_y;
|
||||
|
||||
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, y, stride_y, nrc_x); break;
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#else // !__x86_64__
|
||||
@@ -771,6 +869,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
return true;
|
||||
}
|
||||
|
||||
bool iqk_dequantize_ktquants([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] float * y, [[maybe_unused]] size_t stride_y, [[maybe_unused]] int nrc_x) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -8,6 +8,6 @@
|
||||
|
||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x);
|
||||
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -236,6 +236,8 @@ struct MulMat {
|
||||
static inline bool is_dequant_better(ggml_type type, int nrc_y) {
|
||||
#ifdef __AVX2__
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32;
|
||||
default: break;
|
||||
}
|
||||
@@ -349,22 +351,12 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
if (f32.size() < std::vector<float>::size_type(ne00*this_nrc_x)) f32.resize(ne00*this_nrc_x);
|
||||
iqk_dequantize_iq4_kt(ne00, (const char *)A + (first_x + ix)*strideA, strideA, f32.data(), ne00, this_nrc_x);
|
||||
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f32.data(), ne00, this_nrc_x)) {
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
mm.mul_mat_NxM(ne00, (const char *)f32.data(), row_size_qx, this_info, this_nrc_x, Ny);
|
||||
}
|
||||
|
||||
//thread_local std::vector<float> f32;
|
||||
//if (f32.size() < std::vector<float>::size_type(ne00*nrc_x)) f32.resize(ne00*nrc_x);
|
||||
|
||||
//iqk_dequantize_iq4_kt(ne00, (const char *)A + first_x*strideA, strideA, f32.data(), ne00, nrc_x);
|
||||
|
||||
//size_t row_size_qx = ne00*sizeof(float);
|
||||
//size_t row_size_qy = strideB;
|
||||
|
||||
//DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
|
||||
|
||||
//mm.mul_mat_NxM(ne00, (const char *)f32.data(), row_size_qx, info, nrc_x, Ny);
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
@@ -8243,6 +8243,9 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
|
||||
|
||||
void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) {
|
||||
assert(k % QuantizerIQ2KT::kSuperBlockSize == 0);
|
||||
#ifdef __AVX2__
|
||||
if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return;
|
||||
#endif
|
||||
const int nb = k / QuantizerIQ2KT::kSuperBlockSize;
|
||||
const float * dptr = (const float *)x;
|
||||
const float d = *dptr * QuantizerIQ2KT::kScale;
|
||||
@@ -8496,6 +8499,9 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
|
||||
}
|
||||
|
||||
void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
|
||||
#ifdef __AVX2__
|
||||
if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return;
|
||||
#endif
|
||||
using Q = QuantizerIQ3KT;
|
||||
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
|
||||
assert(k % Q::kSuperBlockSize == 0);
|
||||
@@ -8753,8 +8759,8 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
|
||||
|
||||
void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
|
||||
#ifdef __AVX2__
|
||||
iqk_dequantize_iq4_kt(k, x, 0, y, 0, 1);
|
||||
#else
|
||||
if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return;
|
||||
#endif
|
||||
using Q = QuantizerIQ4KT;
|
||||
assert(k % Q::kSuperBlockSize == 0);
|
||||
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
|
||||
@@ -8782,7 +8788,6 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void vec_dot_iq4_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
|
||||
|
||||
Reference in New Issue
Block a user