mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
wip buggy iq4_KT
This commit is contained in:
@@ -3082,6 +3082,42 @@ static inline __m256 trellis_gen8(uint32_t val) {
|
||||
return _mm256_add_ps(f1, f2);
|
||||
}
|
||||
|
||||
static inline __m256i trellis_next8(uint32_t val1, uint32_t val2) {
|
||||
constexpr uint32_t kmask = 0x8fff8fff;
|
||||
constexpr uint32_t km32 = 0x3b603b60;
|
||||
constexpr uint32_t ka = 89226354;
|
||||
constexpr uint32_t kb = 64248484;
|
||||
constexpr uint32_t ka1 = ka*ka;
|
||||
constexpr uint32_t kb1 = kb*ka+kb;
|
||||
constexpr uint32_t ka2 = ka1*ka;
|
||||
constexpr uint32_t kb2 = kb1*ka+kb;
|
||||
constexpr uint32_t ka3 = ka2*ka;
|
||||
constexpr uint32_t kb3 = kb2*ka+kb;
|
||||
__m256i mka = _mm256_setr_epi32(ka, ka1, ka, ka1, ka2, ka3, ka2, ka3);
|
||||
__m256i mkb = _mm256_setr_epi32(kb, kb1, kb, kb1, kb2, kb3, kb2, kb3);
|
||||
__m256i mval = _mm256_setr_epi32(val1, val1, val2, val2, val1, val1, val2, val2);
|
||||
__m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
|
||||
return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32);
|
||||
}
|
||||
|
||||
static inline __m256 trellis_gen8(uint32_t val1, uint32_t val2) {
|
||||
__m256i i8 = trellis_next8(val1, val2);
|
||||
// split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
|
||||
__m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
|
||||
__m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
|
||||
__m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
|
||||
__m128i lo0123 = _mm256_extracti128_si256(lower_halves_lanes32, 0); // Extracts [00L0, 00L1, 00L2, 00L3]
|
||||
__m128i lo4567 = _mm256_extracti128_si256(lower_halves_lanes32, 1); // Extracts [00L4, 00L5, 00L6, 00L7]
|
||||
__m128i hlo = _mm_packus_epi32(lo0123, lo4567);
|
||||
__m128i hi0123 = _mm256_extracti128_si256(upper_halves_lanes32, 0); // Extracts [00H0, 00H1, 00H2, 00H3]
|
||||
__m128i hi4567 = _mm256_extracti128_si256(upper_halves_lanes32, 1); // Extracts [00H4, 00H5, 00H6, 00H7]
|
||||
__m128i hhi = _mm_packus_epi32(hi0123, hi4567);
|
||||
// widen both to 8xfloat32 and sum
|
||||
__m256 f1 = _mm256_cvtph_ps(hlo);
|
||||
__m256 f2 = _mm256_cvtph_ps(hhi);
|
||||
return _mm256_add_ps(f1, f2);
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -3283,6 +3319,76 @@ static void mul_mat_iq3_KT_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
__m256 accd[nrc_y];
|
||||
__m256 accd2[nrc_y];
|
||||
const float * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
const float d = dptr[0] * 31.75f * 1.01f;
|
||||
const float row_av = dptr[1];
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_setzero_ps();
|
||||
accd2[iy] = _mm256_setzero_ps();
|
||||
}
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
for (int j = 0; j < 128; j+=8) {
|
||||
const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
|
||||
const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
|
||||
const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
|
||||
const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
|
||||
const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
|
||||
const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
|
||||
uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
const __m256 x_val1 = trellis_gen8(val1, val3);
|
||||
const __m256 x_val2 = trellis_gen8(val2, val4);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
accd[iy]
|
||||
);
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
accd[iy]
|
||||
);
|
||||
accd2[iy] = _mm256_add_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
accd2[iy]
|
||||
);
|
||||
accd2[iy] = _mm256_add_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
accd2[iy]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
|
||||
__m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
|
||||
info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // Zen4 or vanilla AVX2
|
||||
|
||||
template <int nrc_y>
|
||||
@@ -9057,6 +9163,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
mm.funcs[7] = mul_mat_iq3_KT_F32_T<8>;
|
||||
expected_typeB = GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
assert (ne00 % QK_K == 0);
|
||||
mm.funcs[0] = mul_mat_iq4_KT_F32_T<1>;
|
||||
mm.funcs[1] = mul_mat_iq4_KT_F32_T<2>;
|
||||
mm.funcs[2] = mul_mat_iq4_KT_F32_T<3>;
|
||||
mm.funcs[3] = mul_mat_iq4_KT_F32_T<4>;
|
||||
mm.funcs[4] = mul_mat_iq4_KT_F32_T<5>;
|
||||
mm.funcs[5] = mul_mat_iq4_KT_F32_T<6>;
|
||||
mm.funcs[6] = mul_mat_iq4_KT_F32_T<7>;
|
||||
mm.funcs[7] = mul_mat_iq4_KT_F32_T<8>;
|
||||
expected_typeB = GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ3K>(mm);
|
||||
|
||||
Reference in New Issue
Block a user