mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 01:24:08 +00:00
Better CPU FA thread strategy
This commit is contained in:
@@ -25259,6 +25259,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
if (k->type == GGML_TYPE_Q8_0) {
|
||||
qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
|
||||
}
|
||||
int nstep_k = k->ne[1]/32;
|
||||
if (nstep_k >= 4*n_tasks && q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1) {
|
||||
size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = size_thread*n_tasks;
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
|
||||
if (k->ne[2] > 1) {
|
||||
int gcd = simple_gcd(k->ne[2], n_tasks);
|
||||
@@ -25279,12 +25285,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
nk = 32*nm;
|
||||
}
|
||||
//int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
|
||||
int nstep_k = k->ne[2]*k->ne[1]/nk;
|
||||
nstep_k = k->ne[2]*k->ne[1]/nk;
|
||||
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = nstep_k*result_size;
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
int nstep_k = k->ne[1]/32;
|
||||
nstep_k = k->ne[1]/32;
|
||||
if (nstep_k >= n_tasks) {
|
||||
size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
|
||||
size_t size = size_thread*n_tasks;
|
||||
cur = MAX(cur, size+qsize);
|
||||
} else {
|
||||
int gcd_k = simple_gcd(nstep_k, n_tasks);
|
||||
if (gcd_k > 1) {
|
||||
int nth_k = n_tasks/gcd_k;
|
||||
@@ -25298,10 +25309,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
}
|
||||
cur = MAX(cur, size+qsize);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cur = MAX(cur, qsize);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
|
||||
@@ -38,14 +38,17 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
if (nq1 == 3) {
|
||||
FlashAttn<576, 512, 3, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 3, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
else if (nq1 == 2) {
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 2, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
} else {
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
|
||||
template <int step_k>
|
||||
|
||||
@@ -1414,7 +1414,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
auto q8 = (typename KHelper::block_q8 *)qptr;
|
||||
// This optimization fails under certain conditions (see https://github.com/ikawrakow/ik_llama.cpp/issues/1205)
|
||||
// => disabling until I figure out what goes wrong
|
||||
if constexpr (false && q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
|
||||
if constexpr (q_step >= 4 && std::is_same_v<KHelper, HelperQ80>) {
|
||||
if (nq1 == q_step) {
|
||||
fms.init_qstep();
|
||||
kh.reset_block();
|
||||
@@ -1424,9 +1424,11 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
auto q8r = (typename HelperQ80R8<Dk>::block_q8 *)qptr;
|
||||
HelperQ80::convert<Dk>(q_step, stride_q, q, q8r);
|
||||
auto mr = mask;
|
||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||
if (k1 > 0 && Mc[0] != 0) break;
|
||||
auto Mc = (const uint16_t *)(mr + (q_step - 1)*stride_m);
|
||||
int ik = nk1 - k_step;
|
||||
for (; ik >=0 && Mc[ik] != 0; ik -= k_step);
|
||||
ik += k_step;
|
||||
for (int k1 = 0; k1 < ik/k_step; ++k1) {
|
||||
HelperQ80R8<Dk>::repack(k_step, kh.block, kh.stride, q8r8);
|
||||
KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms);
|
||||
fqkv.accumulate_qkv(vh, fms);
|
||||
|
||||
@@ -145,6 +145,57 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
// I think it would also speed up things for GQA, but I'm leaving this for another day.
|
||||
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
|
||||
int nstep_k = nek1/32;
|
||||
//if (ith >= nstep_k && ith >= rk2) return true;
|
||||
if (nstep_k >= nth) { //4*nth) {
|
||||
int nstep_k_per_thread = (nstep_k + nth - 1)/nth;
|
||||
int ith_mid = nth;
|
||||
int nstep_k_this_thread = nstep_k_per_thread;
|
||||
if (nstep_k_per_thread*nth > nstep_k) {
|
||||
ith_mid = nstep_k - nth*(nstep_k_per_thread - 1);
|
||||
if (ith >= ith_mid) --nstep_k_this_thread;
|
||||
}
|
||||
//if (ith == 0) fprintf(stderr, "nstep_k = %d, nstep_k_per_thread = %d, ith_mid = %d\n", nstep_k, nstep_k_per_thread, ith_mid);
|
||||
nstep_k_per_thread *= 32;
|
||||
nstep_k_this_thread *= 32;
|
||||
|
||||
auto kv_offset = ith <= ith_mid ? ith*nstep_k_per_thread
|
||||
: ith_mid*nstep_k_per_thread + (ith - ith_mid)*nstep_k_this_thread;
|
||||
auto kth = (const char *)k + kv_offset*stride_k;
|
||||
auto vth = (const char *)v + kv_offset*stride_v;
|
||||
auto qth = (const char *)q;
|
||||
auto mth = (const char *)mask + kv_offset*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
|
||||
auto work = (char *)work_buffer;
|
||||
auto size_thread = (Dv + 16)*rk2*sizeof(float);
|
||||
auto result_buffer = work;
|
||||
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
|
||||
//if (nstep_k_this_thread > 0) {
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, nstep_k_this_thread, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
|
||||
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
|
||||
scale, softcap,
|
||||
work_this_thread, work_this_thread + (Dv+0)*rk2, work_this_thread + (Dv+1)*rk2)) return false;
|
||||
//}
|
||||
|
||||
barrier(barrier_data);
|
||||
|
||||
//int nhave = std::min(nstep_k, nth);
|
||||
for (int j = ith; j < rk2; j += nth) {
|
||||
auto Racc = qkv + j*nb1/sizeof(float);
|
||||
float M = -INFINITY, S = 0;
|
||||
for (int jth = 0; jth < nth; ++jth) {
|
||||
//for (int jth = 0; jth < nhave; ++jth) {
|
||||
auto R = (const float *)(result_buffer + jth*size_thread);
|
||||
auto Mj = R + Dv*rk2;
|
||||
auto Sj = Mj + rk2;
|
||||
R += j*Dv;
|
||||
accumulate_qkv(Dv, M, S, Mj[j], Sj[j], Racc, R);
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int gcd_k = simple_gcd(nstep_k, nth);
|
||||
if (gcd_k >= 1) {
|
||||
int nth_k = nth/gcd_k;
|
||||
@@ -312,6 +363,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||
int iq1 = (ith%ntg)*neq1g;
|
||||
int this_neq1 = std::min(neq1g, neq1-iq1);
|
||||
if (this_neq1 > 0) {
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
|
||||
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
|
||||
@@ -320,6 +372,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
|
||||
scale, softcap,
|
||||
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user