mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Minor tweak
This commit is contained in:
@@ -143,19 +143,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
const int chunks_per_row = D2 / h2_per_chunk;
|
||||
|
||||
int k0_start = 0;
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k;
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -168,6 +168,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
||||
}
|
||||
}
|
||||
k0_start = k0_stop;
|
||||
}
|
||||
} else {
|
||||
static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
|
||||
|
||||
Reference in New Issue
Block a user