Try to improve for unusual number of heads/number of threads

This commit is contained in:
Iwan Kawrakow
2025-04-22 12:37:11 +03:00
parent 39714026fe
commit 9f310ea663
2 changed files with 38 additions and 4 deletions

View File

@@ -21794,7 +21794,24 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
}
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
if (k->ne[2] > 1) {
int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
int gcd = simple_gcd(k->ne[2], n_tasks);
int nth_k = n_tasks/gcd;
int nek2_k = k->ne[2]/gcd;
int nchunk = nek2_k*k->ne[1]/32;
int npt = (nchunk + nth_k - 1)/nth_k;
int nk;
if (npt*nth_k == nchunk) {
nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
} else {
//int nm = std::max(1, npt/8);
int nm = 1;
while (true) {
if (nm*4 >= npt) break;
nm *= 2;
}
nk = 32*nm;
}
//int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
int nstep_k = k->ne[2]*k->ne[1]/nk;
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = nstep_k*result_size;

View File

@@ -63,7 +63,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int int_type_k = int_type_k_in;
auto work_buffer = work_buffer_in;
if (neq1 >= 8 || rk2 >= 4) {
if (neq1 >= 8 || rk2 >= 8) {
uint64_t row_size = 0;
work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
if (int_type_k != int_type_k_in) {
@@ -177,7 +177,24 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
}
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
int nk = std::max(1, 32 * (nek2*nek1/(32*nth)));
int gcd = simple_gcd(nek2, nth);
int nth_k = nth/gcd;
int nek2_k = nek2/gcd;
int nchunk = nek2_k*nek1/32;
int npt = (nchunk + nth_k - 1)/nth_k;
int nk;
if (npt*nth_k == nchunk) {
nk = 32 * (nek2*nek1/(32*nth));
} else {
//int nm = std::max(1, npt/8);
int nm = 1;
while (true) {
if (nm*4 >= npt) break;
nm *= 2;
}
nk = 32*nm;
}
//int nk = 32 * (nek2*nek1/(32*nth));
int nkk = (nek1 + nk - 1)/nk;
int nstep_k = nek2*nkk;
auto result_size = (Dv + 16)*rk2*sizeof(float);
@@ -206,7 +223,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int ik02 = iq2/rk2;
int il = iq2 - ik02*rk2;
auto Racc = qkv + iq2*nb1/sizeof(float);
std::memset(Racc, 0, Dv*sizeof(float));
//std::memset(Racc, 0, Dv*sizeof(float));
float M = -INFINITY, S = 0;
for (int ikk = 0; ikk < nkk; ++ikk) {
int istep_k = ik02*nkk + ikk;