This commit is contained in:
Iwan Kawrakow
2025-04-24 09:25:13 +03:00
parent cd44692bc0
commit 9be8b490b1

View File

@@ -25,6 +25,24 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
}
return a;
}
inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) {
if (Mj == -INFINITY) return;
if (Mj > M) {
if (M == -INFINITY) {
std::memcpy(Racc, R, Dv*sizeof(float));
S = Sj;
} else {
float c = exp(M - Mj);
S = c*S + Sj;
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
}
M = Mj;
} else {
float c = exp(Mj - M);
S += c*Sj;
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
}
}
}
// TODO: get the ggml_type enum here without polution
@@ -151,22 +169,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto Mj = R + Dv*nq_this_j;
auto Sj = Mj + nq_this_j;
R += jj*Dv;
if (Mj[jj] == -INFINITY) continue;
if (Mj[jj] > M) {
if (M == -INFINITY) {
std::memcpy(Racc, R, Dv*sizeof(float));
S = Sj[jj];
} else {
float c = exp(M - Mj[jj]);
S = c*S + Sj[jj];
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
}
M = Mj[jj];
} else {
float c = exp(Mj[jj] - M);
S += c*Sj[jj];
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
}
accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
@@ -177,7 +180,53 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
}
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
auto result_size = (Dv + 16)*rk2*sizeof(float);
int gcd = simple_gcd(nek2, nth);
if (gcd > 1) {
int nth_g = nth/gcd;
int ith_g = ith%nth_g;
int nek1_32 = nek1/32;
int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
int ith_mid = nth_g;
if (nek1_pt*nth_g > nek1_32) {
ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
}
nek1_pt *= 32;
int nek1_mid = ith_mid*nek1_pt;
int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
}
barrier(barrier_data);
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
int ik02 = iq2/rk2;
int il = iq2 - ik02*rk2;
auto Racc = qkv + iq2*nb1/sizeof(float);
float M = -INFINITY, S = 0;
for (int ig = 0; ig < nth_g; ++ig) {
int istep_k = ik02*nth_g + ig;
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
const float * R = this_result + il*Dv;
const float * Mj = this_result + Dv*rk2;
const float * Sj = Mj + rk2;
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
return true;
}
int nth_k = nth/gcd;
int nek2_k = nek2/gcd;
int nchunk = nek2_k*nek1/32;
@@ -197,7 +246,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
//int nk = 32 * (nek2*nek1/(32*nth));
int nkk = (nek1 + nk - 1)/nk;
int nstep_k = nek2*nkk;
auto result_size = (Dv + 16)*rk2*sizeof(float);
//if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
int ik02 = istep_k/nkk;
@@ -231,22 +279,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const float * R = this_result + il*Dv;
const float * Mj = this_result + Dv*rk2;
const float * Sj = Mj + rk2;
if (Mj[il] == -INFINITY) continue;
if (Mj[il] > M) {
if (M == -INFINITY) {
std::memcpy(Racc, R, Dv*sizeof(float));
S = Sj[il];
} else {
float c = exp(M - Mj[il]);
S = c*S + Sj[il];
for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i];
}
M = Mj[il];
} else {
float c = exp(Mj[il] - M);
S += c*Sj[il];
for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
}
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;