This commit is contained in:
Iwan Kawrakow
2025-05-06 19:47:55 +03:00
parent 59a3e361a3
commit 53e7e7790e

View File

@@ -135,7 +135,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
if (use_cp_async) {
if constexpr (use_cp_async) {
const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
constexpr int preload = 64;
@@ -205,7 +205,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
if (use_cp_async) {
if constexpr (use_cp_async) {
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;