From ece257f64565857d7a2e81d354225b54fcdad5dc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 22 Mar 2025 10:55:06 +0200 Subject: [PATCH] Fix it for nth > rk2 --- ggml/src/iqk/iqk_flash_attn.cpp | 49 +++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index fecd818b..3a21db2b 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -65,20 +65,24 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, if (gcd_k >= 1) { int nth_k = nth/gcd_k; if (rk2%nth_k == 0) { + auto work = (char *)work_buffer; + auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float); + auto result_buffer = work; + //if (ith > 0) return true; + //printf("=============== Dk = %d, Dv = %d\n", Dk, Dv); + //for (ith = 0; ith < nth; ++ith) { int ith_k = ith%gcd_k; int ith_q = ith/gcd_k; + //printf("Thread[%2d]: nstep_k=%d, gcd_k=%d, nth_k=%d, ith_k=%d, ith_q=%d\n", ith, nstep_k, gcd_k, nth_k, ith_k, ith_q); auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k; auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v; auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2; auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here - auto work = (char *)work_buffer; // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float) // In addition, we need M, S for the rk2/nth_k rows the thread is processing // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not // writing onto the same cache line. - auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float); - auto result_buffer = work; auto work_this_thread = (float *)(result_buffer + ith*size_thread); if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, @@ -86,16 +90,22 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, scale, softcap, work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false; + //} barrier(barrier_data); + // There are nek1/gcd_k contributions for each j that we need to sum up + // Thread i computed k/v (i%gcd_k)*(nek1/gcd_k) for j (i/gcd_k)*(rk2/nth_k)...((i/gcd_k)+1)*(rk2/nth_k) and results at offset i*size_thread + + //for (ith = 0; ith < nth; ++ith) { // TODO: simdify this for (int j = ith; j < rk2; j += nth) { auto Racc = qkv + j*nb1/sizeof(float); float M = -INFINITY, S = 0; - int jth_q = j/(rk2/nth_k); + // This row was computed by threads j/(rk2/nth_k)*gcd_k...j/(rk2/nth_k)*gcd_k+gcd_k-1 + int jth_first = j/(rk2/nth_k)*gcd_k; int jj = j%(rk2/nth_k); - for (int j1 = 0; j1 < rk2/nth_k; ++j1) { - auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread); + for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) { + auto R = (const float *)(result_buffer + jth*size_thread); auto Mj = R + Dv*rk2/nth_k; auto Sj = Mj + rk2/nth_k; R += jj*Dv; @@ -116,13 +126,40 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; } } + //int jth_q = j/(rk2/nth_k); + //int jj = j%(rk2/nth_k); + //printf("Thread[%2d]: working on %2d: jth_q=%d, jj=%d, suming %d...%d\n", ith, j, jth_q, jj, jth_q*(rk2/nth_k), jth_q*(rk2/nth_k)+rk2/nth_k-1); + //for (int j1 = 0; j1 < rk2/nth_k; ++j1) { + // auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread); + // auto Mj = R + Dv*rk2/nth_k; + // auto Sj = Mj + rk2/nth_k; + // R += jj*Dv; + // if (Mj[jj] == -INFINITY) continue; + // if (Mj[jj] > M) { + // if (M == -INFINITY) { + // std::memcpy(Racc, R, Dv*sizeof(float)); + // S = Sj[jj]; + // } else { + // float c = exp(M - Mj[jj]); + // S = c*S + Sj[jj]; + // for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + // } + // M = Mj[jj]; + // } else { + // float c = exp(Mj[jj] - M); + // S += c*Sj[jj]; + // for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + // } + //} float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; } + //} return true; } } + printf("%s: not using fast path: rk2 = %d, nek1 = %d, gcd_k = %d nth_k = %d\n", __func__, rk2, nek1, gcd_k, nth/gcd_k); } // I keep changing my mind what is the best strategy to split the threads when processing