iq4_k: Rearrange blocks for faster matrix multiplications

On Zen4 we get PP-512(LLaMA-3.1-8B) = 216.7 t/s.
In comparison, the original bit arrangement gave 180 t/s.
The trick is to have quants
  0...3,  64...67,  128...131, 192...195 in block 0,
  4...7,  68...71,  132...135, 196...199 in block 1, etc.
With that, we can simply sum the integer dot products
and multiply by the block scales, without needing
to shuffle scales/dot products and such.
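
As a rough illustration, here is a minimal scalar sketch of the dot
product this layout implies (hypothetical helper; it omits the
iq4k_values lookup and the extra-bit shift that the real kernels in
this commit handle):

  #include <stdint.h>

  // One 256-quant super-block: block ib holds quants 4*ib+j, 4*ib+j+64,
  // 4*ib+j+128, 4*ib+j+192 for j = 0..3, so each block's integer dot
  // product needs only a single multiply by its block scale.
  static float iq4_k_dot_superblock_sketch(const uint8_t *qs,  // 128 bytes of packed 4-bit quants
                                           const int8_t  *q8,  // 256 Q8 activations
                                           const int8_t  *Ls,  // 16 block scales
                                           float d4, float d8) {
      int sum = 0;
      for (int ib = 0; ib < 16; ++ib) {
          int sumi = 0;
          for (int j = 0; j < 4; ++j) {
              sumi += q8[4*ib + j +   0] * (qs[4*ib + j     ] & 0xf);
              sumi += q8[4*ib + j +  64] * (qs[4*ib + j     ] >>  4);
              sumi += q8[4*ib + j + 128] * (qs[4*ib + j + 64] & 0xf);
              sumi += q8[4*ib + j + 192] * (qs[4*ib + j + 64] >>  4);
          }
          sum += Ls[ib] * sumi;   // one scale per block, no shuffling of scales/dots
      }
      return d4 * d8 * sum;       // super-block scales applied once at the end
  }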

iq4_k is now the fastest quantization type on Zen4, so
time to see how this will work on the other platforms.
Iwan Kawrakow
2024-11-04 10:22:59 +02:00
parent 52874c5d21
commit 48974c7acd
5 changed files with 290 additions and 80 deletions

@@ -406,6 +406,7 @@ extern "C" {
GGML_TYPE_IQ4_KS = 144,
GGML_TYPE_IQ2_KS = 145,
GGML_TYPE_IQ4_KSS = 146,
GGML_TYPE_Q8_K16 = 147,
GGML_TYPE_COUNT,
};

@@ -1103,6 +1103,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_K64,
.row_meta_size = 0,
},
[GGML_TYPE_Q8_K16] = {
.type_name = "q8_K16",
.blck_size = QK_K,
.type_size = sizeof(block_q8_K),
.is_quantized = true,
.from_float = iqk_quantize_row_q8_K16,
.row_meta_size = 0,
},
[GGML_TYPE_BF16] = {
.type_name = "bf16",
.blck_size = 1,
@@ -1215,7 +1223,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_k,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_k_ref,
.vec_dot = vec_dot_iq4_k_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.vec_dot_type = GGML_TYPE_Q8_K16,
.nrows = 1,
.row_meta_size = 0,
},
@@ -15456,6 +15464,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K:
case GGML_TYPE_Q8_K16:
case GGML_TYPE_Q8_K64:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:

@@ -988,45 +988,89 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
};
//struct IQXKScales2 {
// IQXKScales2(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
// template <typename Q8>
// inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
// process(i, d, extra, _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)), q8, accm, scales);
// }
// template <typename Q8>
// inline void process(int i, float d, uint16_t extra, __m256i scales16, const Q8& q8, __m256 * accm, __m512i * scales) const {
// auto scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, extra, min, eshift));
// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
// const __m256i prod = _mm256_madd_epi16(scales_s, q8.load_bsums(iy, i));
// accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
// }
// auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
// auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
// auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
// auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
// scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
// scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
// scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
// scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
// }
// const __m256i eshift;
// const __m256i min;
// const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
// const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
// const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
// const __m512i shuffles[2] = {
// _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
// _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
// _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
// _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
// };
//};
// d4*d8*sum[y_i*s_i*(x_i + b_i)] = d4*d8*(sum[s_i*y_i*x_i] + sum[y_i*s_i*b_i]) =
// d4*d8*sum[s_k*(xy_k + by_k)]
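// Here s_k and b_k are constant within a 16-quant block: b_k is the shift
// selected by the extra bit (min1/min2 below, i.e. -128 or -124), and by_k
// is b_k times the block's bsum from Q8_K16, so no per-quant correction is needed.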
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_512()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
inline __m512 new_block(int i, __m512& scales_s) {
d = GGML_FP16_TO_FP32(x[i].d);
auto scales8 = make_scales(x[i].scales_l, x[i].scales_h);
scales_s = _mm512_mask_blend_ps(x[i].extra, min1, min2);
prepare(x[i].qs);
iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(scales8));
//auto scales16 = _mm256_cvtepi8_epi16(scales8);
//scales_s = _mm256_mask_blend_epi16(x[i].extra, min1, min2);
////scales_s = _mm256_mullo_epi16(scales16, _mm256_mask_blend_epi16(x[i].extra, min1, min2));
//prepare(x[i].qs);
//return _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(scales16));
}
inline void prepare(const uint8_t * q4) {
bits.prepare64(q4);
// We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
//                bits.values[1]: 16...31, 48...63, 80...95, 112...127
// etc.
auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
bits.values[0] = _mm512_shuffle_epi8(values, tmp);
tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
bits.values[2] = _mm512_shuffle_epi8(values, tmp);
bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
}
__m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
__m128i make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
//uint8_t Ls[QK_K/16];
//for (int j = 0; j < QK_K/32; ++j) {
// const uint8_t sh = scales_h[j/2] >> 4*(j%2);
// Ls[2*j+0] = ((scales_l[j] & 0xf) | ((sh << 4) & 0x30)) - 32;
// Ls[2*j+1] = ((scales_l[j] >> 4) | ((sh << 2) & 0x30)) - 32;
//}
//return _mm_loadu_si128((const __m128i *)Ls);
uint64_t aux64;
memcpy(&aux64, scales_l, 8);
uint32_t aux32 = *(const uint32_t *)scales_h;
auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
auto sch = _mm_shuffle_epi8(aux, iqxk.scale_shuffle);
return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
auto sch = _mm_shuffle_epi8(aux, scale_shuffle);
return _mm_shuffle_epi8(_mm_add_epi8(_mm_or_si128(scl, sch), m32), scale_shuffle);
}
Q4Bits bits;
const IQXKScales2 iqxk;
const __m512i values;
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
const __m128i maskl = _mm_set1_epi8(0xf);
const __m128i maskh = _mm_set1_epi8(0x30);
const __m128i m32 = _mm_set1_epi8(-32);
const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
const __m512 min1 = _mm512_set1_ps(-128.f);
const __m512 min2 = _mm512_set1_ps(-124.f);
};
struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
@@ -1370,6 +1414,47 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
}
}
template <typename Dequantizer, int nrc_y>
static void mul_mat_iqX_k_q8_K16_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y> q8(info);
Dequantizer deq(vx, bx);
__m512 accd[nrc_y];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
__m512 shifts;
auto scales = deq.new_block(i, shifts);
for (int iy = 0; iy < nrc_y; ++iy) {
auto vd = _mm512_mul_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), scales);
auto bs = _mm512_mul_ps(shifts, _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(q8.load_bsums(iy, i))));
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, deq.bits.values[0], q8.load_quants64(iy, i, 0));
sumi = _mm512_dpbusd_epi32(sumi, deq.bits.values[1], q8.load_quants64(iy, i, 1));
sumi = _mm512_dpbusd_epi32(sumi, deq.bits.values[2], q8.load_quants64(iy, i, 2));
sumi = _mm512_dpbusd_epi32(sumi, deq.bits.values[3], q8.load_quants64(iy, i, 3));
accd[iy] = _mm512_fmadd_ps(vd, _mm512_add_ps(bs, _mm512_cvtepi32_ps(sumi)), accd[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, _mm512_reduce_add_ps(accd[iy]));
}
}
}
template <typename Dequantizer>
static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -3774,7 +3859,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
#ifdef HAVE_FANCY_SIMD
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ6K> ||
std::is_same_v<Dequantizer, DequantizerIQ5K> ||
std::is_same_v<Dequantizer, DequantizerIQ4K> ||
//std::is_same_v<Dequantizer, DequantizerIQ4K> ||
std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
std::is_same_v<Dequantizer, DequantizerIQ4KS>||
@@ -3936,7 +4021,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_IQ4_K:
assert (ne00 % QK_K == 0);
MulMat::set_functions<DequantizerIQ4K>(mm);
mm.funcs[0] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 1>;
mm.funcs[1] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 2>;
mm.funcs[2] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 3>;
mm.funcs[3] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 4>;
mm.funcs[4] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 5>;
mm.funcs[5] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 6>;
mm.funcs[6] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 7>;
mm.funcs[7] = mul_mat_iqX_k_q8_K16_AVX512<DequantizerIQ4K, 8>;
expected_typeB = GGML_TYPE_Q8_K16;
break;
case GGML_TYPE_IQ5_K:
assert (ne00 % QK_K == 0);

@@ -1358,28 +1358,31 @@ void dequantize_row_iq4_k(const block_iq4_k * x, float * y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
int8_t Ls[QK_K/16];
for (int i = 0; i < nb; i++) {
const uint8_t * qs = x[i].qs;
const float d = GGML_FP16_TO_FP32(x[i].d);
uint16_t extra = x[i].extra;
for (int ib = 0; ib < QK_K/32; ++ib) {
const uint8_t sh = x[i].scales_h[ib/2] >> 4*(ib%2);
const float dl1 = d * (((x[i].scales_l[ib] & 0xf) | ((sh << 4) & 0x30)) - 32);
const float dl2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32);
const int8_t * values1 = extra & 1 ? iq4k_values + 16 : iq4k_values;
const int8_t * values2 = extra & 2 ? iq4k_values + 16 : iq4k_values;
extra >>= 2;
for (int j = 0; j < 16; ++j) {
y[j+ 0] = dl1 * values1[qs[j] & 0xf];
y[j+16] = dl2 * values2[qs[j] >> 4];
}
y += 32;
qs += 16;
for (int j = 0; j < QK_K/32; ++j) {
const uint8_t sh = x[i].scales_h[j/2] >> 4*(j%2);
Ls[2*j+0] = ((x[i].scales_l[j] & 0xf) | ((sh << 4) & 0x30)) - 32;
Ls[2*j+1] = ((x[i].scales_l[j] >> 4) | ((sh << 2) & 0x30)) - 32;
}
for (int ib = 0; ib < QK_K/16; ++ib) {
const int8_t * values = extra & 1 ? iq4k_values + 16 : iq4k_values;
extra >>= 1;
const float dl = d*Ls[ib];
for (int j = 0; j < 4; ++j) {
y[4*ib + j + 0] = dl*values[qs[4*ib + j] & 0xf];
y[4*ib + j + 64] = dl*values[qs[4*ib + j] >> 4];
y[4*ib + j + 128] = dl*values[qs[4*ib + j + 64] & 0xf];
y[4*ib + j + 192] = dl*values[qs[4*ib + j + 64] >> 4];
}
}
y += QK_K;
}
}
@@ -1402,31 +1405,34 @@ void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx,
const block_iq4_k * x = (const block_iq4_k *)vx;
const block_q8_K * y = (const block_q8_K *)vy;
int8_t Ls[16];
float sumf = 0;
for (int ibl = 0; ibl < nb; ++ibl) {
const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
uint16_t extra = x[ibl].extra;
uint32_t h = *((const uint32_t *)x[ibl].scales_h);
for (int ib = 0; ib < QK_K/32; ++ib) {
Ls[2*ib+0] = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
Ls[2*ib+1] = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
h >>= 4;
}
const uint8_t * qs = x[ibl].qs;
const int8_t * q8 = y[ibl].qs;
int32_t sum = 0;
for (int ib = 0; ib < QK_K/32; ++ib) {
const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
h >>= 4;
const int8_t * values1 = iq4k_values + 16*(extra & 1);
const int8_t * values2 = iq4k_values + 8*(extra & 2);
extra >>= 2;
int sumi1 = 0, sumi2 = 0;
for (int j = 0; j < 16; ++j) {
sumi1 += q8[j+ 0] * values1[qs[j] & 0xf];
sumi2 += q8[j+16] * values2[qs[j] >> 4];
int bsum = 0;
for (int ib = 0; ib < QK_K/16; ++ib) {
const int8_t * values = extra & 1 ? iq4k_values + 16 : iq4k_values;
extra >>= 1;
int sumi = 0;
for (int j = 0; j < 4; ++j) {
sumi += q8[4*ib + j + 0] * values[qs[4*ib + j] & 0xf];
sumi += q8[4*ib + j + 64] * values[qs[4*ib + j] >> 4];
sumi += q8[4*ib + j + 128] * values[qs[4*ib + j + 64] & 0xf];
sumi += q8[4*ib + j + 192] * values[qs[4*ib + j + 64] >> 4];
}
sum += ls1*sumi1 + ls2*sumi2;
qs += 16;
q8 += 32;
bsum += sumi*Ls[ib];
}
sumf += d4d8 * sum;
sumf += d4d8*bsum;
}
*s = sumf;
@@ -1452,7 +1458,7 @@ inline int best_index_iq4nl(const int8_t * values, float x) {
static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int block_size, const float * x,
block_iq4_k * y,
float * scales, float * weight, uint8_t * L,
float * scales, float * weight, float * xb, uint8_t * L,
const int8_t * values,
const float * quant_weights,
const int ntry) {
@@ -1473,10 +1479,21 @@ static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int b
float max_scale = 0, amax_scale = 0;
uint16_t extra = 0;
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
const float * xb = x + ib*block_size;
const float * x4 = x + 4*ib;
for (int j = 0; j < 4; ++j) {
xb[j+ 0] = x4[j + 0];
xb[j+ 4] = x4[j + 64];
xb[j+ 8] = x4[j + 128];
xb[j+12] = x4[j + 192];
}
if (quant_weights) {
const float * qw = quant_weights + ib*block_size;
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
const float * qw = quant_weights + 4*ib;
for (int j = 0; j < 4; ++j) {
weight[j+ 0] = qw[j+ 0] * sqrtf(sigma2 + xb[j+ 0]*xb[j+ 0]);
weight[j+ 4] = qw[j+ 64] * sqrtf(sigma2 + xb[j+ 4]*xb[j+ 4]);
weight[j+ 8] = qw[j+128] * sqrtf(sigma2 + xb[j+ 8]*xb[j+ 8]);
weight[j+12] = qw[j+192] * sqrtf(sigma2 + xb[j+12]*xb[j+12]);
}
} else {
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
}
@@ -1569,27 +1586,11 @@ static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int b
y->extra = extra;
float id = d ? 1/d : 0.f;
float sumqx = 0, sumq2 = 0;
int8_t Ls[16];
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
const int8_t * block_values = extra & (1 << ib) ? shifted_values : values;
int l = nearest_int(id*scales[ib]);
l = MAX(-32, MIN(31, l));
float dl = d * l;
float idl = dl ? 1/dl : 0.f;
uint8_t * Lb = L + ib*block_size;
const float * xb = x + ib*block_size;
if (quant_weights) {
const float * qw = quant_weights + ib*block_size;
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
} else {
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
}
for (int j = 0; j < block_size; ++j) {
Lb[j] = best_index_iq4nl(block_values, idl*xb[j]);
float w = weight[j];
float q = block_values[Lb[j]]*l;
sumqx += w*q*xb[j];
sumq2 += w*q*q;
}
Ls[ib] = l;
l += 32;
uint8_t l_l = l & 0xf;
uint8_t l_h = l >> 4;
@@ -1597,11 +1598,23 @@ static void quantize_row_iq4_k_impl_bs16(const int super_block_size, const int b
else y->scales_l[ib/2] |= (l_l << 4);
scales_h[ib/8] |= (l_h << 2*(ib%8));
}
for (int j = 0; j < super_block_size; ++j) {
int j4 = j/4;
int ib = j4%16;
float dl = d * Ls[ib];
float idl = dl ? 1/dl : 0.f;
const int8_t * block_values = extra & (1 << ib) ? shifted_values : values;
L[j] = best_index_iq4nl(block_values, idl*x[j]);
float w = quant_weights ? quant_weights[j]*sqrtf(sigma2 + x[j]*x[j]) : x[j]*x[j];
float q = block_values[L[j]]*Ls[ib];
sumqx += w*q*x[j];
sumq2 += w*q*q;
}
if (sumq2 > 0) y->d = GGML_FP32_TO_FP16(sumqx/sumq2);
for (int i = 0; i < super_block_size/32; ++i) {
for (int j = 0; j < 16; ++j) {
y->qs[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
for (int i = 0; i < super_block_size/128; ++i) {
for (int j = 0; j < 64; ++j) {
y->qs[64*i + j] = L[128*i + j] | (L[128*i + j + 64] << 4);
}
}
}
@@ -1625,13 +1638,14 @@ size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
char * qrow = (char *)dst;
uint8_t L[QK_K];
float weight[16];
float xb[16];
float scales[QK_K/16];
for (int64_t row = 0; row < nrows; ++row) {
block_iq4_k * iq4 = (block_iq4_k *)qrow;
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = imatrix ? imatrix + QK_K*ibl : NULL;
quantize_row_iq4_k_impl_bs16(QK_K, 16, src + QK_K*ibl, iq4 + ibl,
scales, weight, L, iq4k_values, qw, 7);
scales, weight, xb, L, iq4k_values, qw, 7);
}
src += n_per_row;
qrow += nblock*sizeof(block_iq4_k);
@@ -2440,6 +2454,98 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) {
}
void iqk_quantize_row_q8_K16(const float * x, void * vy, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_K * y = (block_q8_K *)vy;
#ifdef z__AVX2__
const __m256 signBit = _mm256_set1_ps(-0.0f);
const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
for (int i = 0; i < nb; i++) {
const float * xb = x + i*QK_K;
__m256 maxAbs = _mm256_setzero_ps();
const float * xx = xb;
for (int ib = 0; ib < QK_K/8; ++ib) {
const __m256 v = _mm256_loadu_ps(xx); xx += 8;
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
}
const float maxScalar = hmax_f32_8(maxAbs);
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
xx = xb;
int8_t * q8 = y[i].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
__m256i i0 = _mm256_cvtps_epi32(v0);
__m256i i1 = _mm256_cvtps_epi32(v1);
__m256i i2 = _mm256_cvtps_epi32(v2);
__m256i i3 = _mm256_cvtps_epi32(v3);
y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
}
#else
for (int i = 0; i < nb; i++) {
float amax = 0;
for (int j = 0; j < QK_K; ++j) {
float ax = fabsf(x[j]);
amax = std::max(amax, ax);
}
std::memset(y[i].bsums, 0, 16*sizeof(int16_t));
if (!amax) {
y[i].d = 0;
std::memset(y[i].qs, 0, QK_K);
x += QK_K;
continue;
}
const float iscale = 127.f/amax;
for (int ib = 0; ib < QK_K/16; ++ib) {
int16_t sum = 0;
for (int k = 0; k < 4; ++k) {
int16_t v = nearest_int(iscale*x[4*ib + k + 0]);
sum += v;
y[i].qs[4*ib + k + 0] = v;
v = nearest_int(iscale*x[4*ib + k + 64]);
sum += v;
y[i].qs[4*ib + k + 64] = v;
v = nearest_int(iscale*x[4*ib + k + 128]);
sum += v;
y[i].qs[4*ib + k + 128] = v;
v = nearest_int(iscale*x[4*ib + k + 192]);
sum += v;
y[i].qs[4*ib + k + 192] = v;
}
y[i].bsums[ib] = sum;
}
//for (int j = 0; j < QK_K; ++j) {
// int16_t v = nearest_int(iscale*x[j]);
// y[i].qs[j] = v;
// int j4 = j/4;
// y[i].bsums[j4%16] += v;
//}
y[i].d = 1/iscale;
x += QK_K;
}
#endif
}
namespace {
static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size,
int n_per_row, const float * x, char * cy,

@@ -62,6 +62,7 @@ void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML
void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void iqk_quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
#ifdef __cplusplus
}