mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
The persistent bwd group kernel remapped block->CU as bx/8 + (bx%8)*32. That mapping is a bijection only for N=256, so the kernel produced wrong gradients whenever ROCM_FLASH_ATTN_CU_NUM was set to any other value.

Replace the remap with the flat block index, bx + by*gridDim.x + bz*gridDim.x*gridDim.y, which is a bijection for any 1D or 3D grid.

Also add get_persistent_grid_override(), which honors ROCM_FLASH_ATTN_GRID_Y and ROCM_FLASH_ATTN_GRID_Z to reshape the persistent grid for 1D-vs-3D scheduling experiments.
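The bijectivity argument can be checked on the host with a small brute-force sketch. This is not code from the kernel: the helper names (old_remap, flat_index, old_remap_is_bijective, flat_index_is_bijective) are hypothetical, and the sketch assumes the old formula was applied to a block index bx in [0, N) and was expected to produce a permutation of [0, N).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Old remap quoted in the commit message: bx/8 + (bx%8)*32 (integer division).
inline std::size_t old_remap(std::size_t bx) { return bx / 8 + (bx % 8) * 32; }

// Flat block index: bx + by*gridDim.x + bz*gridDim.x*gridDim.y.
// gx and gy stand in for gridDim.x and gridDim.y.
inline std::size_t flat_index(std::size_t bx, std::size_t by, std::size_t bz,
                              std::size_t gx, std::size_t gy) {
    return bx + by * gx + bz * gx * gy;
}

// True if old_remap sends [0, n) onto [0, n) with no out-of-range
// values and no collisions.
inline bool old_remap_is_bijective(std::size_t n) {
    std::vector<bool> seen(n, false);
    for (std::size_t bx = 0; bx < n; ++bx) {
        const std::size_t target = old_remap(bx);
        if (target >= n || seen[target]) return false;
        seen[target] = true;
    }
    return true;
}

// True if flat_index hits every value in [0, gx*gy*gz) exactly once
// over a gx x gy x gz grid.
inline bool flat_index_is_bijective(std::size_t gx, std::size_t gy, std::size_t gz) {
    const std::size_t total = gx * gy * gz;
    std::vector<bool> seen(total, false);
    for (std::size_t bz = 0; bz < gz; ++bz)
        for (std::size_t by = 0; by < gy; ++by)
            for (std::size_t bx = 0; bx < gx; ++bx) {
                const std::size_t flat = flat_index(bx, by, bz, gx, gy);
                if (flat >= total || seen[flat]) return false;
                seen[flat] = true;
            }
    return true;
}

// Example checks; the grid sizes are arbitrary illustrative values.
inline void check_examples() {
    assert(old_remap_is_bijective(256));   // the one size the old remap was built for
    assert(!old_remap_is_bijective(128));  // bx=1 maps to 32, fine, but bx=4 maps to 128 (out of range)
    assert(!old_remap_is_bijective(512));  // bx=256 maps to 32, colliding with bx=1
    assert(flat_index_is_bijective(304, 1, 1));  // 1D grid
    assert(flat_index_is_bijective(19, 4, 4));   // 3D grid with the same block count
}
```

With N=128 the old formula yields indices up to 239, and with N=512 several blocks land on the same index, which is consistent with the failure described above. The flat index never leaves [0, gridDim.x*gridDim.y*gridDim.z) and never repeats, whatever the grid shape.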