This commit is contained in:
Iwan Kawrakow
2025-12-28 08:57:24 +00:00
parent 9b7d08eaa2
commit ba0e88a5e3
2 changed files with 3 additions and 3 deletions

View File

@@ -1102,8 +1102,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
//const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
//const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);