[rocm-libraries] ROCm/rocm-libraries#7016 (commit 2b73c00)

[CK] Fix RDNA3 FMHA tile-load paths

## Summary

Fix CK tile FMHA paths needed for RDNA3/RDNA4 targets.

## Details

This PR addresses RDNA-specific issues hit while enabling xFormers CK
FMHA on gfx11/gfx12:

- On RDNA3, update FMHA P tile handling so the layout consumed by the
second GEMM matches the WMMA path.

## Testing

Validated downstream with xFormers CK/FMHA on gfx1201/gfx1151.

```text
pytest --import-mode=importlib -q \
  tests/test_mem_eff_attention.py::test_forward \
  tests/test_mem_eff_attention.py::test_backward \
  tests/test_mem_eff_attention.py::test_dropout_ck

3844 passed, 5244 skipped, 26 warnings
This commit is contained in:
Aaryaman Vasishta
2026-05-19 13:42:43 +00:00
committed by assistant-librarian[bot]
parent 424dfec6e4
commit 457f153b69
3 changed files with 26 additions and 1 deletions

View File

@@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p_cast =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(p, p_cast);
#else
const auto p = p_cast;
#endif
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
#if defined(__gfx11__)
// gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM K-per-thread;
// clamping to 8 under-reserves LDS padding for K-per-thread 16 variants.
return GetKVWarpGemmKPerThreadSize<Problem>();
#else
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
return 8;
else
return 4;
#endif
}
template <typename Problem>

View File

@@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
#else
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#endif
__builtin_amdgcn_sched_barrier(0);