Apply offfset to KQ_max in CUDA flash attention

This commit is contained in:
Kawrakow
2026-01-28 12:37:51 +00:00
parent 68cd52e583
commit 7aaed3f488
2 changed files with 4 additions and 4 deletions

View File

@@ -264,7 +264,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
@@ -319,7 +319,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
const int KQ_index = 2*t + (l/2) % 2;
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}

View File

@@ -702,7 +702,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
@@ -756,7 +756,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
const int KQ_index = 2*t + (l/2) % 2;
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
}