Fix bug in the CPU flash attention implementation
@@ -1412,7 +1412,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
                        const float * q, const char * mask, float * qkv,
                        const float * sinkf, float * M, float * S, char * qptr) {
     auto q8 = (typename KHelper::block_q8 *)qptr;
-    if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
+    // This optimization fails under certain conditions (see https://github.com/ikawrakow/ik_llama.cpp/issues/1205)
+    // => disabling until I figure out what goes wrong
+    if constexpr (false && q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
         if (nq1 == q_step) {
             fms.init_qstep();
             kh.reset_block();
@@ -1424,7 +1426,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
             auto mr = mask;
             for (int k1 = 0; k1 < nk1/k_step; ++k1) {
                 auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
-                if (Mc[0] != 0) break;
+                if (k1 > 0 && Mc[0] != 0) break;
                 HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
                 KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
                 fqkv.accumulate_qkv(vh, fms);
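
The first hunk disables the Q8_0 fast path without deleting it: prefixing the `if constexpr` condition with `false` turns the branch into a discarded statement, so it is still parsed but never instantiated into any template specialization. A minimal sketch of that pattern, with hypothetical stand-in types (`HelperQ80` here is just a tag struct, not the real K-cache helper):

    #include <type_traits>

    struct HelperQ80 { };      // hypothetical stand-in for the real K-cache helper
    struct HelperGeneric { };  // hypothetical alternative helper

    template <int q_step, typename KHelper>
    int compute(int nq1) {
        // Same pattern as the commit: the leading `false` makes the whole
        // condition compile-time false, so this branch is discarded and never
        // instantiated, while the code stays in place for later re-enabling.
        if constexpr (false && q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
            return 1;  // fast path (now unreachable)
        }
        return 0;      // generic path, taken unconditionally
    }

    int main() {
        return compute<8, HelperQ80>(8);  // returns 0: the fast path is disabled
    }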
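The second hunk tightens the early-exit test so the loop over K blocks always processes at least its first block: the mask check can now only terminate the loop when k1 > 0. A hedged before/after sketch with simplified stand-in names (`mask_per_block` replaces the real `Mc` pointer arithmetic; that a nonzero mask value means "skip the rest" is my reading of the break condition, not stated in the commit):

    #include <cstdint>
    #include <cstdio>

    // Simplified stand-in for the inner loop of compute_helper_q: the mask value
    // for the current K block decides whether the remaining blocks can be skipped.
    void process_blocks(const uint16_t* mask_per_block, int n_blocks) {
        for (int k1 = 0; k1 < n_blocks; ++k1) {
            // Old condition: if (mask_per_block[k1] != 0) break;
            // That could exit at k1 == 0 and skip the accumulation entirely.
            if (k1 > 0 && mask_per_block[k1] != 0) break;
            std::printf("processing K block %d\n", k1);  // stands in for repack/mul_mask_kq/accumulate_qkv
        }
    }

    int main() {
        const uint16_t mask[4] = {1, 1, 1, 1};  // old logic exited immediately and did no work
        process_blocks(mask, 4);                // fixed logic processes block 0, then stops
        return 0;
    }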