mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-19 04:40:09 +00:00
Revert the commented out section in iqk_mul_mat.cpp
It does have some benefit at long contexts.
This commit is contained in:
@@ -17103,25 +17103,23 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
|
||||
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
|
||||
// Not sure if this actually helps.
|
||||
// So, let's reduce compilation time by commenting it out for now.
|
||||
//if (nk1 >= 256) { //4096) {
|
||||
// if (nq1 >= 64) {
|
||||
// FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
// return;
|
||||
// }
|
||||
// if (nq1 >= 32) {
|
||||
// FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
// return;
|
||||
// }
|
||||
// if (nq1 >= 16) {
|
||||
// FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
|
||||
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
// return;
|
||||
// }
|
||||
//}
|
||||
if (nk1 >= 256) { //4096) {
|
||||
if (nq1 >= 64) {
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
}
|
||||
if (nq1 >= 32) {
|
||||
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
}
|
||||
if (nq1 >= 16) {
|
||||
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
|
||||
Reference in New Issue
Block a user