Better CPU FA thread strategy

This commit is contained in:
Kawrakow
2026-01-31 15:46:16 +00:00
parent 33308908db
commit 2bf2fa8ba4
4 changed files with 84 additions and 13 deletions

View File

@@ -25259,6 +25259,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (k->type == GGML_TYPE_Q8_0) {
qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
}
int nstep_k = k->ne[1]/32;
if (nstep_k >= 4*n_tasks && q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1) {
size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = size_thread*n_tasks;
cur = MAX(cur, size+qsize);
} else {
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
if (k->ne[2] > 1) {
int gcd = simple_gcd(k->ne[2], n_tasks);
@@ -25279,12 +25285,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
nk = 32*nm;
}
//int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
int nstep_k = k->ne[2]*k->ne[1]/nk;
nstep_k = k->ne[2]*k->ne[1]/nk;
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = nstep_k*result_size;
cur = MAX(cur, size+qsize);
} else {
int nstep_k = k->ne[1]/32;
nstep_k = k->ne[1]/32;
if (nstep_k >= n_tasks) {
size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = size_thread*n_tasks;
cur = MAX(cur, size+qsize);
} else {
int gcd_k = simple_gcd(nstep_k, n_tasks);
if (gcd_k > 1) {
int nth_k = n_tasks/gcd_k;
@@ -25298,10 +25309,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
}
cur = MAX(cur, size+qsize);
}
}
}
} else {
cur = MAX(cur, qsize);
}
}
#endif
} break;
case GGML_OP_FLASH_ATTN_BACK:

View File

@@ -38,14 +38,17 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(4*n_step)) return;
}
if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(2*n_step)) return;
if (nq1 == 3) {
FlashAttn<576, 512, 3, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 3, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
else if (nq1 == 2) {
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 2, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
} else {
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
template <int step_k>

View File

@@ -1414,7 +1414,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
auto q8 = (typename KHelper::block_q8 *)qptr;
// This optimization previously failed under certain conditions (see https://github.com/ikawrakow/ik_llama.cpp/issues/1205)
// => re-enabled for q_step >= 4, now that the mask scan below bounds the processed k range (ik)
if constexpr (false && q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
if constexpr (q_step >= 4 && std::is_same_v<KHelper, HelperQ80>) {
if (nq1 == q_step) {
fms.init_qstep();
kh.reset_block();
@@ -1424,9 +1424,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
auto q8r = (typename HelperQ80R8<Dk>::block_q8 *)qptr;
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
if (k1 > 0 && Mc[0] != 0) break;
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
int ik = nk1 - k_step;
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
ik += k_step;
for (int k1 = 0; k1 < ik/k_step; ++k1) {
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
fqkv.accumulate_qkv(vh, fms);

View File

@@ -145,6 +145,57 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
// I think it would also speed up things for GQA, but I'm leaving this for another day.
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
int nstep_k = nek1/32;
//if (ith >= nstep_k && ith >= rk2) return true;
if (nstep_k >= nth) { //4*nth) {
int nstep_k_per_thread = (nstep_k + nth - 1)/nth;
int ith_mid = nth;
int nstep_k_this_thread = nstep_k_per_thread;
if (nstep_k_per_thread*nth > nstep_k) {
ith_mid = nstep_k - nth*(nstep_k_per_thread - 1);
if (ith >= ith_mid) --nstep_k_this_thread;
}
//if (ith == 0) fprintf(stderr, "nstep_k = %d, nstep_k_per_thread = %d, ith_mid = %d\n", nstep_k, nstep_k_per_thread, ith_mid);
nstep_k_per_thread *= 32;
nstep_k_this_thread *= 32;
auto kv_offset = ith <= ith_mid ? ith*nstep_k_per_thread
: ith_mid*nstep_k_per_thread + (ith - ith_mid)*nstep_k_this_thread;
auto kth = (const char *)k + kv_offset*stride_k;
auto vth = (const char *)v + kv_offset*stride_v;
auto qth = (const char *)q;
auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here
auto work = (char *)work_buffer;
auto size_thread = (Dv + 16)*rk2*sizeof(float);
auto result_buffer = work;
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
//if (nstep_k_this_thread > 0) {
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, nstep_k_this_thread, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
scale, softcap,
work_this_thread, work_this_thread + (Dv+0)*rk2, work_this_thread + (Dv+1)*rk2)) return false;
//}
barrier(barrier_data);
//int nhave = std::min(nstep_k, nth);
for (int j = ith; j < rk2; j += nth) {
auto Racc = qkv + j*nb1/sizeof(float);
float M = -INFINITY, S = 0;
for (int jth = 0; jth < nth; ++jth) {
//for (int jth = 0; jth < nhave; ++jth) {
auto R = (const float *)(result_buffer + jth*size_thread);
auto Mj = R + Dv*rk2;
auto Sj = Mj + rk2;
R += j*Dv;
accumulate_qkv(Dv, M, S, Mj[j], Sj[j], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
return true;
}
int gcd_k = simple_gcd(nstep_k, nth);
if (gcd_k >= 1) {
int nth_k = nth/gcd_k;
@@ -312,6 +363,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1g;
int this_neq1 = std::min(neq1g, neq1-iq1);
if (this_neq1 > 0) {
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
@@ -320,6 +372,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
scale, softcap,
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
}
}
}
}