From 4239d259a61b3ec752e5f776b7d016487bca6a34 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Fri, 15 Aug 2025 16:43:04 +0300
Subject: [PATCH] Quick hack to improve TG performance for SWA models (#692)

Co-authored-by: Iwan Kawrakow
---
 ggml/src/iqk/iqk_flash_attn.cpp | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp
index 19e6cd25..ccd81079 100644
--- a/ggml/src/iqk/iqk_flash_attn.cpp
+++ b/ggml/src/iqk/iqk_flash_attn.cpp
@@ -80,6 +80,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
     int rk3 = neq3/nek3;
     int rv3 = neq3/nev3;
 
+    int first_k = 0, last_k = nek1;
+    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
+        // This is a quick hack for SWA models.
+        // Given that the mask is the same for all layers, ideally we should determine the
+        // cache bounds once and reuse them for the whole graph. But even with this simple hack
+        // we get non-negligible performance gains for SWA models and long contexts.
+        auto umask = (const uint16_t *)mask;
+        for (; first_k < last_k; ++first_k) {
+            if (umask[first_k] == 0) break;
+        }
+        for (; last_k > first_k; --last_k) {
+            if (umask[last_k-1] == 0) break;
+        }
+        //printf("nek1 = %d, first_k = %d, last_k = %d\n", nek1, first_k, last_k);
+        if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
+            //printf("Reducing from %d to %d\n", nek1, last_k - first_k);
+            k    = (const void *)((const char *)k + first_k*stride_k);
+            v    = (const void *)((const char *)v + first_k*stride_v);
+            mask = (const void *)((const uint16_t *)mask + first_k);
+            nek1 = last_k - first_k;
+        }
+    }
+
     int int_type_k = int_type_k_in;
     auto work_buffer = work_buffer_in;
     if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
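
For reference, here is the bound-trimming idea from the hunk above in isolation (an illustrative sketch, not part of the patch; the helper name trim_kv_bounds is made up). It assumes, as the patch does, that the mask is stored as f16 values reinterpreted as uint16_t, where a bit pattern of 0 (i.e. 0.0f) marks a position that may be attended to and any nonzero pattern (e.g. -inf) marks a masked-out one:

    #include <cstdint>
    #include <utility>

    // Returns [first_k, last_k): the KV window remaining after skipping
    // masked-out positions at the front and back of the cache. Positions
    // inside the window may still be masked; the mask is applied as usual.
    std::pair<int, int> trim_kv_bounds(const uint16_t * umask, int nek1) {
        int first_k = 0, last_k = nek1;
        while (first_k < last_k && umask[first_k]    != 0) ++first_k; // skip masked head
        while (last_k > first_k && umask[last_k - 1] != 0) --last_k;  // skip masked tail
        return {first_k, last_k};
    }

The patch only uses the narrowed window when it is at most 3*nek1/4 (so shifting the k/v/mask pointers is worth the trouble) and a multiple of 32 (presumably to preserve the block-size assumptions of the flash-attention kernels); otherwise the full cache is processed as before.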