mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
The persistent bwd group kernel remapped block->CU as bx/8 + (bx%8)*32. That remap is a bijection on [0, N) for N=256 but not for general N: other sizes yield duplicate or out-of-range indices, which produced wrong gradients when ROCM_FLASH_ATTN_CU_NUM was set to a different value.

Replace it with the flat blockIdx, bx + by*gridDim.x + bz*gridDim.x*gridDim.y, which is a bijection for any 1D or 3D grid.

Also add get_persistent_grid_override(), which honors ROCM_FLASH_ATTN_GRID_Y and ROCM_FLASH_ATTN_GRID_Z to reshape the persistent grid for 1D-vs-3D scheduling experiments.
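
The index arithmetic above can be checked on the host. The following is a minimal sketch, not the kernel code: the helper names (`legacy_remap`, `flat_block_index`, `legacy_remap_is_bijective`, `flat_index_is_bijective`) are hypothetical, and only the two formulas come from this commit message. It treats a mapping as valid when every block receives a distinct index inside [0, N).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Old remap quoted in the commit message: bx/8 + (bx%8)*32.
inline std::uint32_t legacy_remap(std::uint32_t bx) {
    return bx / 8 + (bx % 8) * 32;
}

// Flat block index for a grid of extent (gx, gy, gz); gz does not appear
// in the formula, it only bounds bz.
inline std::uint32_t flat_block_index(std::uint32_t bx, std::uint32_t by,
                                      std::uint32_t bz, std::uint32_t gx,
                                      std::uint32_t gy) {
    return bx + by * gx + bz * gx * gy;
}

// True if legacy_remap sends [0, n) onto [0, n) without duplicates.
inline bool legacy_remap_is_bijective(std::uint32_t n) {
    assert(n > 0);
    std::vector<bool> seen(n, false);
    for (std::uint32_t bx = 0; bx < n; ++bx) {
        const std::uint32_t id = legacy_remap(bx);
        if (id >= n || seen[id]) return false;  // out of range or collision
        seen[id] = true;
    }
    return true;
}

// True if the flat index visits every slot of [0, gx*gy*gz) exactly once.
inline bool flat_index_is_bijective(std::uint32_t gx, std::uint32_t gy,
                                    std::uint32_t gz) {
    const std::uint32_t total = gx * gy * gz;
    assert(total > 0);
    std::vector<bool> seen(total, false);
    for (std::uint32_t bz = 0; bz < gz; ++bz)
        for (std::uint32_t by = 0; by < gy; ++by)
            for (std::uint32_t bx = 0; bx < gx; ++bx) {
                const std::uint32_t id = flat_block_index(bx, by, bz, gx, gy);
                if (id >= total || seen[id]) return false;
                seen[id] = true;
            }
    return true;
}
```

With these helpers, `legacy_remap_is_bijective(256)` holds, while sizes such as 128 (index 128 is out of range for bx=4) and 512 (bx=1 and bx=256 both map to 32) fail. The flat index passes for both a 1D grid and a 3D grid; the sizes used here are illustrative only.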