Minor tweak

This commit is contained in:
Iwan Kawrakow
2025-05-07 09:07:34 +03:00
parent 53e7e7790e
commit 1982beb005

View File

@@ -143,19 +143,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const int chunks_per_row = D2 / h2_per_chunk;
int k0_start = 0;
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
if (k0_start == k0_stop) {
continue;
}
const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k;
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
break;
@@ -168,6 +168,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
}
}
k0_start = k0_stop;
}
} else {
static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");