Fix bug in the CPU flash attention implementation

Kawrakow
2026-01-30 09:50:48 +02:00
parent 686fd1ebec
commit efd331f3eb


@@ -1412,7 +1412,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
  const float * q, const char * mask, float * qkv,
  const float * sinkf, float * M, float * S, char * qptr) {
  auto q8 = (typename KHelper::block_q8 *)qptr;
- if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
+ // This optimization fails under certain conditions (see https://github.com/ikawrakow/ik_llama.cpp/issues/1205)
+ // => disabling until I figure out what goes wrong
+ if constexpr (false && q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
  if (nq1 == q_step) {
  fms.init_qstep();
  kh.reset_block();
@@ -1424,7 +1426,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
  auto mr = mask;
  for (int k1 = 0; k1 < nk1/k_step; ++k1) {
  auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
- if (Mc[0] != 0) break;
+ if (k1 > 0 && Mc[0] != 0) break;
  HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
  KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
  fqkv.accumulate_qkv(vh, fms);
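
The first hunk takes the Q8_0 fast path out of service with the `if constexpr (false && ...)` idiom: the leading compile-time `false` makes the whole condition a constant, so the branch is discarded at template instantiation and never emitted, while the code stays in the tree so it can be re-enabled once the failure behind issue 1205 is understood. Below is a minimal standalone sketch of that idiom; the names (compute, KHelper, q_step) are placeholders, not the actual ik_llama.cpp helpers.

#include <cstdio>
#include <type_traits>

// Sketch of the disabling idiom used in the commit; compute/KHelper are
// hypothetical stand-ins, not the real flash attention kernel.
template <typename KHelper, int q_step>
void compute(int nq1) {
    // The leading compile-time 'false' turns the whole condition into a
    // constant, so this branch is discarded during instantiation.
    if constexpr (false && q_step > 1 && std::is_same_v<KHelper, float>) {
        std::printf("fast path: %d rows\n", nq1);    // never emitted while disabled
    } else {
        std::printf("generic path: %d rows\n", nq1); // always taken while disabled
    }
}

int main() {
    compute<float, 4>(8);  // prints "generic path: 8 rows"
    return 0;
}

Removing the `false &&` again restores the fast path without any other change, which matches the comment's intent of disabling the optimization only temporarily.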