Add __syncthreads() to the new FA kernel

This commit is contained in:
Iwan Kawrakow
2025-05-20 15:34:27 +03:00
parent 2ec2229f2e
commit 651fed0848

View File

@@ -1093,6 +1093,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
// do we really need this?
__syncthreads();
// Write back combined meta data:
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
@@ -1112,6 +1115,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
} else if (np > 1) {
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
// Therefore, all other warps also need to execute a __syncthreads().
// Otherwise the points at which warps synchronize with each other would become misaligned.
__syncthreads();
}
#pragma unroll