iq3_kt (0.3t/s eval) and renames

This commit is contained in:
Andrew Keen Chan
2025-05-19 03:03:05 +00:00
parent c4e5d3e382
commit 04eb150b9f

View File

@@ -3083,7 +3083,7 @@ static inline __m256 trellis_gen8(uint32_t val) {
}
template <int nrc_y>
static void mul_mat_q2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static void mul_mat_iq2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
@@ -3154,7 +3154,7 @@ static void mul_mat_q2_KT_q8_K_T(int n, const void * vx, size_t bx, const DataIn
}
template <int nrc_y>
static void mul_mat_q2_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static void mul_mat_iq2_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
@@ -3200,6 +3200,89 @@ static void mul_mat_q2_KT_F32_T(int n, const void * vx, size_t bx, const DataInf
}
}
static inline __m256 abs_ps(__m256 vals) {
// Clear sign-bit of all the 32-bit floats in vals
__m256 sign_bit = _mm256_set1_ps(-0.0f);
return _mm256_andnot_ps(sign_bit, vals);
}
// Negates 32-bit float lanes of an 8x32-bit vector
// based on 8x8-bit condition var. For float lane i, if byte i of
// `condition` is nonzero, the float will be negated.
static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
__m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
// Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero,
// else 0x00 (upper bytes are meaningless)
__m128i zeros = _mm_setzero_si128();
__m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
__m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
// Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
// expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
__m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
// Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
__m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
// Negate via XOR on sign bits of each 32-bit float
__m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
__m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
__m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
return _mm256_xor_ps(vals, xor_mask_ps);
}
template <int nrc_y>
static void mul_mat_iq3_KT_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
__m256 accd[nrc_y];
const float * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
const float d = *dptr * 31.75f * 1.015f;
const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const uint16_t * ql = (const uint16_t *)x[i].ql;
const uint8_t * qh = x[i].qh;
for (int j = 0; j < 128; j+=8) {
uint64_t mask1 = 0x0101010101010101 << (j/32);
uint64_t mask2 = mask1 << 4;
uint32_t val1 = ql[j/8] + 4096;
uint32_t val2 = ql[j/8+16] + 4096;
const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
const float x_scale1 = (x[i].scales[j/32] & 0xf);
const float x_scale2 = (x[i].scales[j/32] >> 4);
const __m256 x_val1 = abs_ps(trellis_gen8(val1));
const __m256 x_val2 = abs_ps(trellis_gen8(val2));
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(
conditional_negate_ps(
_mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
),
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
accd[iy]
);
accd[iy] = _mm256_fmadd_ps(
conditional_negate_ps(
_mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
),
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
accd[iy]
);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
info.store(ix, iy, hsum_float_8(res));
}
}
}
#endif // Zen4 or vanilla AVX2
template <int nrc_y>
@@ -8944,22 +9027,34 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
break;
case GGML_TYPE_IQ2_KT:
assert (ne00 % QK_K == 0);
// mm.funcs[0] = mul_mat_q2_KT_q8_K_T<1>;
// mm.funcs[1] = mul_mat_q2_KT_q8_K_T<2>;
// mm.funcs[2] = mul_mat_q2_KT_q8_K_T<3>;
// mm.funcs[3] = mul_mat_q2_KT_q8_K_T<4>;
// mm.funcs[4] = mul_mat_q2_KT_q8_K_T<5>;
// mm.funcs[5] = mul_mat_q2_KT_q8_K_T<6>;
// mm.funcs[6] = mul_mat_q2_KT_q8_K_T<7>;
// mm.funcs[7] = mul_mat_q2_KT_q8_K_T<8>;
mm.funcs[0] = mul_mat_q2_KT_F32_T<1>;
mm.funcs[1] = mul_mat_q2_KT_F32_T<2>;
mm.funcs[2] = mul_mat_q2_KT_F32_T<3>;
mm.funcs[3] = mul_mat_q2_KT_F32_T<4>;
mm.funcs[4] = mul_mat_q2_KT_F32_T<5>;
mm.funcs[5] = mul_mat_q2_KT_F32_T<6>;
mm.funcs[6] = mul_mat_q2_KT_F32_T<7>;
mm.funcs[7] = mul_mat_q2_KT_F32_T<8>;
// mm.funcs[0] = mul_mat_iq2_KT_q8_K_T<1>;
// mm.funcs[1] = mul_mat_iq2_KT_q8_K_T<2>;
// mm.funcs[2] = mul_mat_iq2_KT_q8_K_T<3>;
// mm.funcs[3] = mul_mat_iq2_KT_q8_K_T<4>;
// mm.funcs[4] = mul_mat_iq2_KT_q8_K_T<5>;
// mm.funcs[5] = mul_mat_iq2_KT_q8_K_T<6>;
// mm.funcs[6] = mul_mat_iq2_KT_q8_K_T<7>;
// mm.funcs[7] = mul_mat_iq2_KT_q8_K_T<8>;
mm.funcs[0] = mul_mat_iq2_KT_F32_T<1>;
mm.funcs[1] = mul_mat_iq2_KT_F32_T<2>;
mm.funcs[2] = mul_mat_iq2_KT_F32_T<3>;
mm.funcs[3] = mul_mat_iq2_KT_F32_T<4>;
mm.funcs[4] = mul_mat_iq2_KT_F32_T<5>;
mm.funcs[5] = mul_mat_iq2_KT_F32_T<6>;
mm.funcs[6] = mul_mat_iq2_KT_F32_T<7>;
mm.funcs[7] = mul_mat_iq2_KT_F32_T<8>;
expected_typeB = GGML_TYPE_F32;
break;
case GGML_TYPE_IQ3_KT:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_iq3_KT_F32_T<1>;
mm.funcs[1] = mul_mat_iq3_KT_F32_T<2>;
mm.funcs[2] = mul_mat_iq3_KT_F32_T<3>;
mm.funcs[3] = mul_mat_iq3_KT_F32_T<4>;
mm.funcs[4] = mul_mat_iq3_KT_F32_T<5>;
mm.funcs[5] = mul_mat_iq3_KT_F32_T<6>;
mm.funcs[6] = mul_mat_iq3_KT_F32_T<7>;
mm.funcs[7] = mul_mat_iq3_KT_F32_T<8>;
expected_typeB = GGML_TYPE_F32;
break;
case GGML_TYPE_IQ3_K: