Use mul_mat_qX_0_q8_2_Tx for q6_0 in flash attention (FA)

This commit is contained in:
Iwan Kawrakow
2025-04-23 14:12:08 +03:00
parent 9f310ea663
commit ddcdf25e54

View File

@@ -16841,6 +16841,9 @@ struct FlashQKfp32 {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 1, k_step>, 1);
if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 2, k_step>, 2);
if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q6_0_1_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q6_0_1_Unpacker, nq);
#endif
}
@@ -17094,7 +17097,7 @@ struct FlashAttn {
std::is_same_v<KHelper, HelperQ80<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
std::is_same_v<KHelper, HelperQ8KVR8<Dk, k_step>>) {
constexpr size_t kMaxOnStackSize = 18432; //576;
constexpr size_t kMaxOnStackSize = 576;
auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8);
q_size = GGML_PAD(q_size, 64);
if (q_size > kMaxOnStackSize) {