Improve DeepSeek batched processing speed (#282)

* Improve DeepSeek batched processing speed

* Revert the commented out section in iqk_mul_mat.cpp

It does have some benefit at long contexts.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-23 17:10:52 +01:00
committed by GitHub
parent 5a4855e61c
commit f9307d7907
2 changed files with 15 additions and 3 deletions

View File

@@ -17265,13 +17265,25 @@ template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
if (nq1 % 8 == 0) {
if (nq1 >= 8) {
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
} else {
}
else if (nq1 >= 4) {
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
else {
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
//if (nq1 % 8 == 0) {
// FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
//} else {
// FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
// fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
//}
}
template <int step_k>