Quick hack to improve TG performance for SWA models (#692)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-08-15 16:43:04 +03:00
committed by GitHub
parent fc06bc9d27
commit 4239d259a6

View File

@@ -80,6 +80,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int rk3 = neq3/nek3;
int rv3 = neq3/nev3;
int first_k = 0, last_k = nek1;
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
// This is a quick hack for SWA models.
// Given that the mask is the same for all layers, ideally we should determinbe the
// cache bounds once, and reuse for the whole graph. But even with this simple hack
// we get non-negligible performance gains for SWA models and long context.
auto umask = (const uint16_t *)mask;
for (; first_k < last_k; ++first_k) {
if (umask[first_k] == 0) break;
}
for (; last_k > first_k; --last_k) {
if (umask[last_k-1] == 0) break;
}
//printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
//printf("Reducing from %d to %d\n", nek1, last_k - first_k);
k = (const void *)((const char *)k + first_k*stride_k);
v = (const void *)((const char *)v + first_k*stride_v);
mask = (const void *)((const uint16_t *)mask + first_k);
nek1 = last_k - first_k;
}
}
int int_type_k = int_type_k_in;
auto work_buffer = work_buffer_in;
if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {