Update FlashAttn comment

Iwan Kawrakow
2024-09-01 12:47:50 +03:00
parent 94439ea73c
commit a66d1fc562


@@ -6057,12 +6057,16 @@ inline __m256 v_tanh(__m256 x) {
namespace {
// In some functions below we have two branches, one for small attention heads, and one for large.
// My intent was to invoke the "small head" branch for attention heads with <= 128 elements, and the
// "large head" branch for attention heads of size 256 (e.g., Gemma-2). But the "large head" branch
// has a bug that I don't see, so for now we use the "small head" branch for all head sizes.
// We definitely run out of SIMD registers for a head size of 256, so performance is sub-optimal
// in this case. But it is still much better than mainline ggml and llama.cpp.
// Some of the methods in FlashAttn have two identical implementations that differ only in that
// one version loops over the template parameter q_step, while the other loops over an input
// parameter nq (these are loops over the rows of q^T). I dislike this a lot,
// but performance drops significantly if I remove the version with fixed q_step iterations.
// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D),
// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops,
// and use the variable nq version (with lower performance) only for the remaining 1...q_step-1
// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
// process a template parameter of such functions, but this would result in the compiler generating
// q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
template <int D, int q_step, int k_step>
struct FlashAttn {
static_assert(D%16 == 0 && D <= 256);
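
The row-splitting scheme described in the new comment can be illustrated with a minimal, self-contained sketch. It is not the FlashAttn code: `process_fixed`, `process_variable` and `process_rows` are hypothetical names, and the per-row work (summing the D elements of a row) is only a placeholder for the real attention computation. It shows the two duplicated loops, one with a compile-time trip count and one with a run-time trip count, and how Nq rows are divided between them.

```cpp
#include <cassert>
#include <vector>

// Hypothetical placeholder kernel: the loop over rows has a compile-time
// trip count (q_step), so the compiler is free to unroll it completely.
template <int q_step>
void process_fixed(const float* q, int D, float* out) {
    for (int j = 0; j < q_step; ++j) {
        float sum = 0.0f;
        for (int i = 0; i < D; ++i) sum += q[j*D + i];
        out[j] = sum;
    }
}

// Same body, but the number of rows nq is only known at run time.
void process_variable(int nq, const float* q, int D, float* out) {
    for (int j = 0; j < nq; ++j) {
        float sum = 0.0f;
        for (int i = 0; i < D; ++i) sum += q[j*D + i];
        out[j] = sum;
    }
}

// Process q_step*(Nq/q_step) rows with the fixed-step kernel and the
// remaining Nq % q_step rows (if any) with the variable-step kernel.
template <int q_step>
std::vector<float> process_rows(const std::vector<float>& q, int Nq, int D) {
    assert(Nq >= 0 && D > 0 && (int)q.size() >= Nq*D);
    std::vector<float> out(Nq, 0.0f);
    const int n_full = q_step*(Nq/q_step);
    for (int row = 0; row < n_full; row += q_step) {
        process_fixed<q_step>(q.data() + row*D, D, out.data() + row);
    }
    if (n_full < Nq) {
        process_variable(Nq - n_full, q.data() + n_full*D, D, out.data() + n_full);
    }
    return out;
}
```

With `process_rows<8>(q, 11, D)`, for example, the first 8 rows go through one call to `process_fixed<8>` and the last 3 through `process_variable`. Making the remainder a template parameter instead would need a separate instantiation for each possible remainder 1...q_step-1, which is the trade-off the comment mentions.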