mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 18:10:02 +00:00
Minor
This commit is contained in:
@@ -155,20 +155,22 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
|
||||
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
|
||||
int nk = 32 * (nek2*nek1/(32*nth));
|
||||
int nkk = nek1/nk;
|
||||
int nkk = (nek1 + nk - 1)/nk;
|
||||
int nstep_k = nek2*nkk;
|
||||
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
||||
//if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k);
|
||||
for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) {
|
||||
int ik02 = istep_k/nkk;
|
||||
int ik01 = nk*(istep_k - ik02*nkk);
|
||||
int this_nk = ik01 + nk <= nek1 ? nk : nek1 - ik01;
|
||||
if (this_nk <= 0) break;
|
||||
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
|
||||
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
|
||||
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
|
||||
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
|
||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, nk, nbq2, stride_k, stride_v, 0, Dv,
|
||||
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user