composable_kernel/include
Ye Wang d7bb3b10cc [fmha-bwd] Flat cu_id remap for arbitrary CTA_NUM + grid Y/Z override env
The persistent bwd group kernel remapped block->CU as bx/8 + (bx%8)*32,
which is a bijection onto [0, N) only for N=256 CTAs and yields wrong gradients when ROCM_FLASH_ATTN_CU_NUM
is set to any other value. Replace with flat blockIdx (bx + by*gridDim.x +
bz*gridDim.x*gridDim.y), a bijection for any 1D/3D grid. Also add
get_persistent_grid_override() honoring ROCM_FLASH_ATTN_GRID_Y/Z to reshape
the persistent grid for 1D-vs-3D scheduling experiments.
2026-06-11 13:19:51 +00:00