diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a2bdc156..036bd8a8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21771,15 +21771,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa if (gcd_k > 1) { int nth_k = n_tasks/gcd_k; int rk2 = q->ne[2]/k->ne[2]; - if (rk2%nth_k == 0) { - size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks; - if (ggml_is_quantized(k->type)) { - enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; - size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); - size += q->ne[2]*row_size; - } - cur = MAX(cur, size); + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; } + cur = MAX(cur, size); } } #endif diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 3a21db2b..04264a06 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -64,19 +64,32 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, int gcd_k = simple_gcd(nstep_k, nth); if (gcd_k >= 1) { int nth_k = nth/gcd_k; - if (rk2%nth_k == 0) { - auto work = (char *)work_buffer; - auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float); - auto result_buffer = work; + int ith_k = ith%gcd_k; + int ith_q = ith/gcd_k; + // nth = 24, nek1 = 256, rk2 = 16 -> gcd_k = 8, nth_k = 3, nq_per_thread = 6 + // nq_per_thread*nth_k = 18 > 16 -> ith_mid = 1, nq_this_thread = 5 for ith_q >= 1, j_mid = 6 + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + int ith_mid = nth_k; + int nq_this_thread = nq_per_thread; + if (nq_per_thread*nth_k > rk2) { + // ith_mid*nq_per_thread + (nth_k - ith_mid)*(nq_per_thread - 1) = rk2 + // -> ith_mid = rk2 - nth_k*(nq_per_thread - 1) + ith_mid = rk2 - nth_k*(nq_per_thread - 1); + if (ith_q >= ith_mid) --nq_this_thread; + } + int j_mid = ith_mid*nq_per_thread; + auto work = (char *)work_buffer; + auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float); + auto result_buffer = work; + if (nq_this_thread > 0) { //if (ith > 0) return true; //printf("=============== Dk = %d, Dv = %d\n", Dk, Dv); //for (ith = 0; ith < nth; ++ith) { - int ith_k = ith%gcd_k; - int ith_q = ith/gcd_k; //printf("Thread[%2d]: nstep_k=%d, gcd_k=%d, nth_k=%d, ith_k=%d, ith_q=%d\n", ith, nstep_k, gcd_k, nth_k, ith_k, ith_q); auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k; auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v; - auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2; + auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2; + auto qth = (const char *)q + q_offset; auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float) @@ -85,79 +98,95 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, // writing onto the same cache line. auto work_this_thread = (float *)(result_buffer + ith*size_thread); if (!iqk_flash_attn_impl(int_type_k, int_type_v, - Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, + Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, scale, softcap, - work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false; + work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false; //} - barrier(barrier_data); - - // There are nek1/gcd_k contributions for each j that we need to sum up - // Thread i computed k/v (i%gcd_k)*(nek1/gcd_k) for j (i/gcd_k)*(rk2/nth_k)...((i/gcd_k)+1)*(rk2/nth_k) and results at offset i*size_thread - - //for (ith = 0; ith < nth; ++ith) { - // TODO: simdify this - for (int j = ith; j < rk2; j += nth) { - auto Racc = qkv + j*nb1/sizeof(float); - float M = -INFINITY, S = 0; - // This row was computed by threads j/(rk2/nth_k)*gcd_k...j/(rk2/nth_k)*gcd_k+gcd_k-1 - int jth_first = j/(rk2/nth_k)*gcd_k; - int jj = j%(rk2/nth_k); - for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) { - auto R = (const float *)(result_buffer + jth*size_thread); - auto Mj = R + Dv*rk2/nth_k; - auto Sj = Mj + rk2/nth_k; - R += jj*Dv; - if (Mj[jj] == -INFINITY) continue; - if (Mj[jj] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[jj]; - } else { - float c = exp(M - Mj[jj]); - S = c*S + Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[jj]; - } else { - float c = exp(Mj[jj] - M); - S += c*Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } - } - //int jth_q = j/(rk2/nth_k); - //int jj = j%(rk2/nth_k); - //printf("Thread[%2d]: working on %2d: jth_q=%d, jj=%d, suming %d...%d\n", ith, j, jth_q, jj, jth_q*(rk2/nth_k), jth_q*(rk2/nth_k)+rk2/nth_k-1); - //for (int j1 = 0; j1 < rk2/nth_k; ++j1) { - // auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread); - // auto Mj = R + Dv*rk2/nth_k; - // auto Sj = Mj + rk2/nth_k; - // R += jj*Dv; - // if (Mj[jj] == -INFINITY) continue; - // if (Mj[jj] > M) { - // if (M == -INFINITY) { - // std::memcpy(Racc, R, Dv*sizeof(float)); - // S = Sj[jj]; - // } else { - // float c = exp(M - Mj[jj]); - // S = c*S + Sj[jj]; - // for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - // } - // M = Mj[jj]; - // } else { - // float c = exp(Mj[jj] - M); - // S += c*Sj[jj]; - // for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - // } - //} - float norm = S > 0 ? 1/S : 1; - for (int i = 0; i < Dv; ++i) Racc[i] *= norm; - } - //} - return true; - } + barrier(barrier_data); + + // There are nek1/gcd_k contributions for each j that we need to sum up + // Thread i computed k/v (i%gcd_k)*(nek1/gcd_k) for j (i/gcd_k)*(rk2/nth_k)...((i/gcd_k)+1)*(rk2/nth_k) and results at offset i*size_thread + + //for (ith = 0; ith < nth; ++ith) { + // TODO: simdify this + // TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now + for (int j = ith; j < rk2; j += nth) { + auto Racc = qkv + j*nb1/sizeof(float); + float M = -INFINITY, S = 0; + // This row was computed by threads j/(rk2/nth_k)*gcd_k...j/(rk2/nth_k)*gcd_k+gcd_k-1 + int jth_first, jj, nq_this_j; + // j = 0....5 -> jth_first = 0, jj = 0...5 + // j = 6...10 -> jth_first = 8, jj = 0...4 + // j = 11...15 -> jth_first = 16, jj = 0...4 + if (j < j_mid) { + jth_first = j/nq_per_thread; + jj = j%nq_per_thread; + nq_this_j = nq_per_thread; + } else { + jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1); + jj = (j - j_mid)%(nq_per_thread-1); + nq_this_j = nq_per_thread - 1; + } + jth_first *= gcd_k; + //int jth_first = j/(rk2/nth_k)*gcd_k; + //int jj = j%(rk2/nth_k); + for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) { + auto R = (const float *)(result_buffer + jth*size_thread); + auto Mj = R + Dv*nq_this_j; + auto Sj = Mj + nq_this_j; + R += jj*Dv; + if (Mj[jj] == -INFINITY) continue; + if (Mj[jj] > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj[jj]; + } else { + float c = exp(M - Mj[jj]); + S = c*S + Sj[jj]; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj[jj]; + } else { + float c = exp(Mj[jj] - M); + S += c*Sj[jj]; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } + } + //int jth_q = j/(rk2/nth_k); + //int jj = j%(rk2/nth_k); + //printf("Thread[%2d]: working on %2d: jth_q=%d, jj=%d, suming %d...%d\n", ith, j, jth_q, jj, jth_q*(rk2/nth_k), jth_q*(rk2/nth_k)+rk2/nth_k-1); + //for (int j1 = 0; j1 < rk2/nth_k; ++j1) { + // auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread); + // auto Mj = R + Dv*rk2/nth_k; + // auto Sj = Mj + rk2/nth_k; + // R += jj*Dv; + // if (Mj[jj] == -INFINITY) continue; + // if (Mj[jj] > M) { + // if (M == -INFINITY) { + // std::memcpy(Racc, R, Dv*sizeof(float)); + // S = Sj[jj]; + // } else { + // float c = exp(M - Mj[jj]); + // S = c*S + Sj[jj]; + // for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + // } + // M = Mj[jj]; + // } else { + // float c = exp(Mj[jj] - M); + // S += c*Sj[jj]; + // for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + // } + //} + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + //} + return true; + + //} } printf("%s: not using fast path: rk2 = %d, nek1 = %d, gcd_k = %d nth_k = %d\n", __func__, rk2, nek1, gcd_k, nth/gcd_k); }