Handle rk2%nth_k != 0

This commit is contained in:
Iwan Kawrakow
2025-03-22 11:53:15 +02:00
parent ece257f645
commit 988be1f8f0
2 changed files with 111 additions and 83 deletions

View File

@@ -21771,8 +21771,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (gcd_k > 1) {
int nth_k = n_tasks/gcd_k;
int rk2 = q->ne[2]/k->ne[2];
if (rk2%nth_k == 0) {
size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
if (ggml_is_quantized(k->type)) {
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
@@ -21781,7 +21781,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur = MAX(cur, size);
}
}
}
#endif
} break;
case GGML_OP_FLASH_ATTN_BACK:

View File

@@ -64,19 +64,32 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
int gcd_k = simple_gcd(nstep_k, nth);
if (gcd_k >= 1) {
int nth_k = nth/gcd_k;
if (rk2%nth_k == 0) {
int ith_k = ith%gcd_k;
int ith_q = ith/gcd_k;
// nth = 24, nek1 = 256, rk2 = 16 -> gcd_k = 8, nth_k = 3, nq_per_thread = 6
// nq_per_thread*nth_k = 18 > 16 -> ith_mid = 1, nq_this_thread = 5 for ith_q >= 1, j_mid = 6
int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
int ith_mid = nth_k;
int nq_this_thread = nq_per_thread;
if (nq_per_thread*nth_k > rk2) {
// ith_mid*nq_per_thread + (nth_k - ith_mid)*(nq_per_thread - 1) = rk2
// -> ith_mid = rk2 - nth_k*(nq_per_thread - 1)
ith_mid = rk2 - nth_k*(nq_per_thread - 1);
if (ith_q >= ith_mid) --nq_this_thread;
}
int j_mid = ith_mid*nq_per_thread;
auto work = (char *)work_buffer;
auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float);
auto result_buffer = work;
if (nq_this_thread > 0) {
//if (ith > 0) return true;
//printf("=============== Dk = %d, Dv = %d\n", Dk, Dv);
//for (ith = 0; ith < nth; ++ith) {
int ith_k = ith%gcd_k;
int ith_q = ith/gcd_k;
//printf("Thread[%2d]: nstep_k=%d, gcd_k=%d, nth_k=%d, ith_k=%d, ith_q=%d\n", ith, nstep_k, gcd_k, nth_k, ith_k, ith_q);
auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k;
auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v;
auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2;
auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2;
auto qth = (const char *)q + q_offset;
auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here
// Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float)
@@ -85,12 +98,13 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
// writing onto the same cache line.
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
scale, softcap,
work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false;
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
//}
}
barrier(barrier_data);
// There are nek1/gcd_k contributions for each j that we need to sum up
@@ -98,16 +112,31 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
//for (ith = 0; ith < nth; ++ith) {
// TODO: simdify this
// TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now
for (int j = ith; j < rk2; j += nth) {
auto Racc = qkv + j*nb1/sizeof(float);
float M = -INFINITY, S = 0;
// This row was computed by threads j/(rk2/nth_k)*gcd_k...j/(rk2/nth_k)*gcd_k+gcd_k-1
int jth_first = j/(rk2/nth_k)*gcd_k;
int jj = j%(rk2/nth_k);
int jth_first, jj, nq_this_j;
// j = 0....5 -> jth_first = 0, jj = 0...5
// j = 6...10 -> jth_first = 8, jj = 0...4
// j = 11...15 -> jth_first = 16, jj = 0...4
if (j < j_mid) {
jth_first = j/nq_per_thread;
jj = j%nq_per_thread;
nq_this_j = nq_per_thread;
} else {
jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1);
jj = (j - j_mid)%(nq_per_thread-1);
nq_this_j = nq_per_thread - 1;
}
jth_first *= gcd_k;
//int jth_first = j/(rk2/nth_k)*gcd_k;
//int jj = j%(rk2/nth_k);
for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) {
auto R = (const float *)(result_buffer + jth*size_thread);
auto Mj = R + Dv*rk2/nth_k;
auto Sj = Mj + rk2/nth_k;
auto Mj = R + Dv*nq_this_j;
auto Sj = Mj + nq_this_j;
R += jj*Dv;
if (Mj[jj] == -INFINITY) continue;
if (Mj[jj] > M) {
@@ -157,7 +186,7 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
//}
return true;
}
//}
}
printf("%s: not using fast path: rk2 = %d, nek1 = %d, gcd_k = %d nth_k = %d\n", __func__, rk2, nek1, gcd_k, nth/gcd_k);
}