mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-28 10:21:48 +00:00
Minor tweak
This commit is contained in:
@@ -143,19 +143,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|||||||
|
|
||||||
const int chunks_per_row = D2 / h2_per_chunk;
|
const int chunks_per_row = D2 / h2_per_chunk;
|
||||||
|
|
||||||
|
int k0_start = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
|
for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
|
||||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
|
||||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||||
const int stride_i = WARP_SIZE / stride_k;
|
|
||||||
|
|
||||||
if (k0_start == k0_stop) {
|
if (k0_start == k0_stop) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int stride_i = WARP_SIZE / stride_k;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k;
|
||||||
|
|
||||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||||
break;
|
break;
|
||||||
@@ -168,6 +168,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|||||||
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
k0_start = k0_stop;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
|
static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
|
||||||
|
|||||||
Reference in New Issue
Block a user