mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-19 20:54:36 +00:00
Quick hack to improve TG performance for SWA models (#692)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -80,6 +80,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
int rk3 = neq3/nek3;
|
||||
int rv3 = neq3/nev3;
|
||||
|
||||
int first_k = 0, last_k = nek1;
|
||||
if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
|
||||
// This is a quick hack for SWA models.
|
||||
// Given that the mask is the same for all layers, ideally we should determinbe the
|
||||
// cache bounds once, and reuse for the whole graph. But even with this simple hack
|
||||
// we get non-negligible performance gains for SWA models and long context.
|
||||
auto umask = (const uint16_t *)mask;
|
||||
for (; first_k < last_k; ++first_k) {
|
||||
if (umask[first_k] == 0) break;
|
||||
}
|
||||
for (; last_k > first_k; --last_k) {
|
||||
if (umask[last_k-1] == 0) break;
|
||||
}
|
||||
//printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
|
||||
if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
|
||||
//printf("Reducing from %d to %d\n", nek1, last_k - first_k);
|
||||
k = (const void *)((const char *)k + first_k*stride_k);
|
||||
v = (const void *)((const char *)v + first_k*stride_v);
|
||||
mask = (const void *)((const uint16_t *)mask + first_k);
|
||||
nek1 = last_k - first_k;
|
||||
}
|
||||
}
|
||||
|
||||
int int_type_k = int_type_k_in;
|
||||
auto work_buffer = work_buffer_in;
|
||||
if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
|
||||
|
||||
Reference in New Issue
Block a user