diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 20f9cd38..d4de298e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -25253,68 +25253,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                         cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
 #if GGML_USE_IQK_MULMAT
-                        size_t qsize = 0;
-                        const struct ggml_tensor * q = node->src[0];
-                        const struct ggml_tensor * k = node->src[1];
-                        if (k->type == GGML_TYPE_Q8_0) {
-                            qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]);
-                        }
-                        int nstep_k = k->ne[1]/32;
-                        if (nstep_k >= 4*n_tasks && q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1) {
-                            size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
-                            size_t size = size_thread*n_tasks;
-                            cur = MAX(cur, size+qsize);
-                        } else {
-                            if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
-                                if (k->ne[2] > 1) {
-                                    int gcd = simple_gcd(k->ne[2], n_tasks);
-                                    int nth_k = n_tasks/gcd;
-                                    int nek2_k = k->ne[2]/gcd;
-                                    int nchunk = nek2_k*k->ne[1]/32;
-                                    int npt = (nchunk + nth_k - 1)/nth_k;
-                                    int nk;
-                                    if (npt*nth_k == nchunk) {
-                                        nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks));
-                                    } else {
-                                        //int nm = std::max(1, npt/8);
-                                        int nm = 1;
-                                        while (true) {
-                                            if (nm*4 >= npt) break;
-                                            nm *= 2;
-                                        }
-                                        nk = 32*nm;
-                                    }
-                                    //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
-                                    nstep_k = k->ne[2]*k->ne[1]/nk;
-                                    size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
-                                    size_t size = nstep_k*result_size;
-                                    cur = MAX(cur, size+qsize);
-                                } else {
-                                    nstep_k = k->ne[1]/32;
-                                    if (nstep_k >= n_tasks) {
-                                        size_t size_thread = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
-                                        size_t size = size_thread*n_tasks;
-                                        cur = MAX(cur, size+qsize);
-                                    } else {
-                                        int gcd_k = simple_gcd(nstep_k, n_tasks);
-                                        if (gcd_k > 1) {
-                                            int nth_k = n_tasks/gcd_k;
-                                            int rk2 = q->ne[2]/k->ne[2];
-                                            int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
-                                            size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks;
-                                            if (ggml_is_quantized(k->type)) {
-                                                enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
-                                                size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
-                                                size += q->ne[2]*row_size;
-                                            }
-                                            cur = MAX(cur, size+qsize);
-                                        }
-                                    }
-                                }
-                            } else {
-                                cur = MAX(cur, qsize);
-                            }
-                        }
+                        size_t size = iqk_fa_work_buffer_size(node, n_tasks);
+                        cur = MAX(cur, size);
 #endif
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index a7cb30a0..47e55b0e 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -7,6 +7,7 @@
 #include "iqk_config.h"
 #include "iqk_mul_mat.h"
 #include "iqk_flash_impl.h"
+#include "ggml.h"
 
 #if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
 
@@ -45,6 +46,39 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
     }
 }
 
+size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nth) {
+    auto Q = dst->src[0];
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    int rk2 = Q->ne[2]/K->ne[2];
+    size_t size = 0;
+    if (K->type == GGML_TYPE_Q8_0 && (Q->ne[1] >= 8 || (rk2 >= 8 && K->ne[2] > 1))) {
+        size = ggml_row_size(GGML_TYPE_Q8_0, K->ne[0]) * K->ne[1]*K->ne[2]*K->ne[3];
+    }
+    int nstep_k = K->ne[1]/32;
+    if (nstep_k >= 4*nth) {
+        auto size_thread = (V->ne[0] + 16)*rk2*sizeof(float);
+        size += size_thread*nth;
+        return size;
+    }
+    int gcd_k = simple_gcd(nstep_k, nth);
+    if (gcd_k >= 1) {
+        int nth_k = nth/gcd_k;
+        int nq_per_thread = (rk2 + nth_k - 1)/nth_k;
+        if (nq_per_thread > 1) {
+            auto size_thread = (V->ne[0] + 16)*nq_per_thread*sizeof(float);
+            size += size_thread*nth;
+            return size;
+        }
+    }
+    int rv2 = Q->ne[2] / V->ne[2];
+    if (Q->ne[1] == 1 && Q->ne[3] == 1 && rk2 > 1 && rk2 == rv2 && K->ne[1]*K->ne[2] >= 32*nth) {
+        auto result_size = (V->ne[0] + 16)*rk2*sizeof(float);
+        size += result_size*nth;
+    }
+    return size;
+}
+
 // TODO: get the ggml_type enum here without polution
 //
 extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
@@ -145,8 +179,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
     // I think it would also speed up things for GQA, but I'm leaving this for another day.
     if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) {
         int nstep_k = nek1/32;
-        //if (ith >= nstep_k && ith >= rk2) return true;
-        if (nstep_k >= nth) { //4*nth) {
+        if (nstep_k >= 4*nth) {
             int nstep_k_per_thread = (nstep_k + nth - 1)/nth;
             int ith_mid = nth;
             int nstep_k_this_thread = nstep_k_per_thread;
@@ -169,22 +202,18 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
             auto size_thread = (Dv + 16)*rk2*sizeof(float);
             auto result_buffer = work;
             auto work_this_thread = (float *)(result_buffer + ith*size_thread);
-            //if (nstep_k_this_thread > 0) {
             if (!iqk_flash_attn_impl(int_type_k, int_type_v,
                         Dk, Dv, rk2, nstep_k_this_thread, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
                         (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
                         nullptr, 0, scale, softcap,
                         work_this_thread, work_this_thread + (Dv+0)*rk2, work_this_thread + (Dv+1)*rk2)) return false;
-            //}
 
             barrier(barrier_data);
 
-            //int nhave = std::min(nstep_k, nth);
             for (int j = ith; j < rk2; j += nth) {
                 auto Racc = qkv + j*nb1/sizeof(float);
                 float M = -INFINITY, S = 0;
                 for (int jth = 0; jth < nth; ++jth) {
-                    //for (int jth = 0; jth < nhave; ++jth) {
                     auto R = (const float *)(result_buffer + jth*size_thread);
                     auto Mj = R + Dv*rk2;
                     auto Sj = Mj + rk2;
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 3c1250e2..60c66d4a 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -7,6 +7,7 @@
 #pragma once
 #include <stdint.h>
 #include <stdbool.h>
+#include <stddef.h>
 #include "iqk_config.h"
 #ifdef __cplusplus
 extern "C" {
@@ -37,6 +38,10 @@ IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int un
 
 IQK_API int iqk_dequant_type(int type, int Ny);
 
+struct ggml_tensor;
+
+IQK_API size_t iqk_fa_work_buffer_size(const struct ggml_tensor * dst, int nthread);
+
 typedef void (*barrier_t) (void *);
 
 IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
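
Note on the per-thread slot layout that iqk_fa_work_buffer_size reserves: the reduction loop above reads each thread's slot as [R: Dv*rk2 floats][M: rk2 floats][S: rk2 floats] (see the Mj = R + Dv*rk2 and Sj = Mj + rk2 pointer arithmetic), so only (Dv + 2)*rk2 of the (Dv + 16)*rk2 floats allocated per thread are actually addressed; the remainder is presumably padding/headroom. The following is a toy, self-contained sketch of the split-K merge those slots feed, in the spirit of accumulate_qkv; the names Partial and merge_partials are illustrative, not the project's.

#include <cmath>
#include <vector>

// One worker's contribution for a single head: the unnormalized
// accumulator R (Dv floats, i.e. the sum of exp(s - M)*v over the K/V
// chunk this worker processed), the local logit maximum M, and the
// partial softmax denominator S.
struct Partial {
    std::vector<float> R;
    float M;
    float S;
};

// Online-softmax merge of per-thread partials into one normalized output
// vector of Dv floats, mirroring the M/S rescaling done per head in the
// reduction loop of iqk_flash_attn_noalibi.
static void merge_partials(int Dv, const std::vector<Partial> & parts, float * out) {
    float M = -INFINITY, S = 0.0f;
    for (int i = 0; i < Dv; ++i) out[i] = 0.0f;
    for (const auto & p : parts) {
        if (p.M == -INFINITY) continue;      // worker saw no valid positions
        if (p.M > M) {
            // New global maximum: rescale everything accumulated so far.
            // std::exp(-INFINITY) == 0, so the first partial is taken verbatim.
            float c = std::exp(M - p.M);
            for (int i = 0; i < Dv; ++i) out[i] = c*out[i] + p.R[i];
            S = c*S + p.S;
            M = p.M;
        } else {
            // Otherwise rescale the incoming partial to the current maximum.
            float c = std::exp(p.M - M);
            for (int i = 0; i < Dv; ++i) out[i] += c*p.R[i];
            S += c*p.S;
        }
    }
    if (S > 0.0f) {
        float norm = 1.0f/S;
        for (int i = 0; i < Dv; ++i) out[i] *= norm;
    }
}

The real code performs this merge in place, each thread ith walking heads j = ith, ith + nth, ... and accumulating directly into the qkv output row, so no extra buffer beyond the per-thread slots is needed.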