mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq3_kt (0.3t/s eval) and renames
This commit is contained in:
@@ -3083,7 +3083,7 @@ static inline __m256 trellis_gen8(uint32_t val) {
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
static void mul_mat_iq2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
@@ -3154,7 +3154,7 @@ static void mul_mat_q2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_q2_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
static void mul_mat_iq2_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
@@ -3200,6 +3200,89 @@ static void mul_mat_q2_KT_F32_T(int n, const void * vx, size_t bx, const DataInf
|
||||
}
|
||||
}
|
||||
|
||||
static inline __m256 abs_ps(__m256 vals) {
|
||||
// Clear sign-bit of all the 32-bit floats in vals
|
||||
__m256 sign_bit = _mm256_set1_ps(-0.0f);
|
||||
return _mm256_andnot_ps(sign_bit, vals);
|
||||
}
|
||||
|
||||
// Negates 32-bit float lanes of an 8x32-bit vector
|
||||
// based on 8x8-bit condition var. For float lane i, if byte i of
|
||||
// `condition` is nonzero, the float will be negated.
|
||||
static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
|
||||
__m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
|
||||
// Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero,
|
||||
// else 0x00 (upper bytes are meaningless)
|
||||
__m128i zeros = _mm_setzero_si128();
|
||||
__m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
|
||||
__m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
|
||||
// Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
|
||||
// expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
|
||||
__m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
|
||||
// Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
|
||||
__m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
|
||||
// Negate via XOR on sign bits of each 32-bit float
|
||||
__m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
|
||||
__m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
|
||||
__m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
|
||||
return _mm256_xor_ps(vals, xor_mask_ps);
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq3_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
__m256 accd[nrc_y];
|
||||
const float * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
const float d = *dptr * 31.75f * 1.015f;
|
||||
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
for (int j = 0; j < 128; j+=8) {
|
||||
uint64_t mask1 = 0x0101010101010101 << (j/32);
|
||||
uint64_t mask2 = mask1 << 4;
|
||||
uint32_t val1 = ql[j/8] + 4096;
|
||||
uint32_t val2 = ql[j/8+16] + 4096;
|
||||
const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
|
||||
const float x_scale1 = (x[i].scales[j/32] & 0xf);
|
||||
const float x_scale2 = (x[i].scales[j/32] >> 4);
|
||||
const __m256 x_val1 = abs_ps(trellis_gen8(val1));
|
||||
const __m256 x_val2 = abs_ps(trellis_gen8(val2));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
conditional_negate_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
|
||||
),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
accd[iy]
|
||||
);
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
conditional_negate_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
|
||||
),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
accd[iy]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
|
||||
info.store(ix, iy, hsum_float_8(res));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // Zen4 or vanilla AVX2
|
||||
|
||||
template <int nrc_y>
|
||||
@@ -8944,22 +9027,34 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
break;
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
assert (ne00 % QK_K == 0);
|
||||
// mm.funcs[0] = mul_mat_q2_KT_q8_K_T<1>;
|
||||
// mm.funcs[1] = mul_mat_q2_KT_q8_K_T<2>;
|
||||
// mm.funcs[2] = mul_mat_q2_KT_q8_K_T<3>;
|
||||
// mm.funcs[3] = mul_mat_q2_KT_q8_K_T<4>;
|
||||
// mm.funcs[4] = mul_mat_q2_KT_q8_K_T<5>;
|
||||
// mm.funcs[5] = mul_mat_q2_KT_q8_K_T<6>;
|
||||
// mm.funcs[6] = mul_mat_q2_KT_q8_K_T<7>;
|
||||
// mm.funcs[7] = mul_mat_q2_KT_q8_K_T<8>;
|
||||
mm.funcs[0] = mul_mat_q2_KT_F32_T<1>;
|
||||
mm.funcs[1] = mul_mat_q2_KT_F32_T<2>;
|
||||
mm.funcs[2] = mul_mat_q2_KT_F32_T<3>;
|
||||
mm.funcs[3] = mul_mat_q2_KT_F32_T<4>;
|
||||
mm.funcs[4] = mul_mat_q2_KT_F32_T<5>;
|
||||
mm.funcs[5] = mul_mat_q2_KT_F32_T<6>;
|
||||
mm.funcs[6] = mul_mat_q2_KT_F32_T<7>;
|
||||
mm.funcs[7] = mul_mat_q2_KT_F32_T<8>;
|
||||
// mm.funcs[0] = mul_mat_iq2_KT_q8_K_T<1>;
|
||||
// mm.funcs[1] = mul_mat_iq2_KT_q8_K_T<2>;
|
||||
// mm.funcs[2] = mul_mat_iq2_KT_q8_K_T<3>;
|
||||
// mm.funcs[3] = mul_mat_iq2_KT_q8_K_T<4>;
|
||||
// mm.funcs[4] = mul_mat_iq2_KT_q8_K_T<5>;
|
||||
// mm.funcs[5] = mul_mat_iq2_KT_q8_K_T<6>;
|
||||
// mm.funcs[6] = mul_mat_iq2_KT_q8_K_T<7>;
|
||||
// mm.funcs[7] = mul_mat_iq2_KT_q8_K_T<8>;
|
||||
mm.funcs[0] = mul_mat_iq2_KT_F32_T<1>;
|
||||
mm.funcs[1] = mul_mat_iq2_KT_F32_T<2>;
|
||||
mm.funcs[2] = mul_mat_iq2_KT_F32_T<3>;
|
||||
mm.funcs[3] = mul_mat_iq2_KT_F32_T<4>;
|
||||
mm.funcs[4] = mul_mat_iq2_KT_F32_T<5>;
|
||||
mm.funcs[5] = mul_mat_iq2_KT_F32_T<6>;
|
||||
mm.funcs[6] = mul_mat_iq2_KT_F32_T<7>;
|
||||
mm.funcs[7] = mul_mat_iq2_KT_F32_T<8>;
|
||||
expected_typeB = GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
assert (ne00 % QK_K == 0);
|
||||
mm.funcs[0] = mul_mat_iq3_KT_F32_T<1>;
|
||||
mm.funcs[1] = mul_mat_iq3_KT_F32_T<2>;
|
||||
mm.funcs[2] = mul_mat_iq3_KT_F32_T<3>;
|
||||
mm.funcs[3] = mul_mat_iq3_KT_F32_T<4>;
|
||||
mm.funcs[4] = mul_mat_iq3_KT_F32_T<5>;
|
||||
mm.funcs[5] = mul_mat_iq3_KT_F32_T<6>;
|
||||
mm.funcs[6] = mul_mat_iq3_KT_F32_T<7>;
|
||||
mm.funcs[7] = mul_mat_iq3_KT_F32_T<8>;
|
||||
expected_typeB = GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
|
||||
Reference in New Issue
Block a user