From a66d1fc56250790cd36efea4fa7bfc45113e9ada Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 1 Sep 2024 12:47:50 +0300 Subject: [PATCH] Update FlashAttn comment --- ggml/src/iqk/iqk_mul_mat.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3ce249e9..55dd016c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6057,12 +6057,16 @@ inline __m256 v_tanh(__m256 x) { namespace { -// In some functions below we have two branches, one for small attention heads, and one for large -// My intent was to invoke the "small head" branch for attention heads with <= 128 elements, and the -// "large head" branch for attentions heads of size 256 (e.g., Gemma-2). But the "large head" branch -// has a bug that I don't see, so for now we use the "small head" branch for all head sizes. -// We definitely run out of SIMD registers for head siZe of 256, so performance is sub-optimal -// in this case. But it is still much better than mainline ggml and llama.cpp +// Some of the methods in FlashAttn have two identical implementations that only differ by +// one version using a loop over the template parameter q_step, while the other using a loop +// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, +// but performance drops significantly if I remove the version with fixed q_step iterations. +// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D), +// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops, +// and use the variable nq version (with lower performance) only for the remaining 1...q_step-1 +// rows (if Nq is not a multiple of q_step). 
One could have made the number of q^T rows to +// process a template parameter of such functions, but this would result in the compiler generating +// q_step-1 versions of these functions for us, which I thought was too much with q_step = 8. template struct FlashAttn { static_assert(D%16 == 0 && D <= 256);