From 457f153b69472b84fa1819d384ee451632091467 Mon Sep 17 00:00:00 2001
From: Aaryaman Vasishta <2500920+jammm@users.noreply.github.com>
Date: Tue, 19 May 2026 13:42:43 +0000
Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7016 (commit 2b73c00)

[CK] Fix RDNA3 FMHA tile-load paths

## Summary

Fix CK tile FMHA paths needed for RDNA3/RDNA4 targets.

## Details

This PR addresses RDNA-specific issues hit while enabling xFormers CK
FMHA on gfx11/gfx12:

- On RDNA3, update FMHA P tile handling so the layout consumed by the
  second GEMM matches the WMMA path.

## Testing

Validated downstream with xFormers CK/FMHA on gfx1201/gfx1151.

```text
pytest --import-mode=importlib -q \
  tests/test_mem_eff_attention.py::test_forward \
  tests/test_mem_eff_attention.py::test_backward \
  tests/test_mem_eff_attention.py::test_dropout_ck

3844 passed, 5244 skipped, 26 warnings
```
---
 ...block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 12 +++++++++++-
 ...line_qr_ks_vs_whole_k_prefetch_default_policy.hpp |  6 ++++++
 .../fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp   |  9 +++++++++
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
index 607ee70020..73b84594e7 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
@@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
 
             __builtin_amdgcn_sched_barrier(0x00000001);
 
-            auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile));
+            auto p_cast =
+                cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile));
+#if defined(__gfx11__)
+            // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
+            // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
+            auto p = make_static_distributed_tensor(
+                decltype(gemm_1)::template MakeABlockTileDistribution());
+            PermuteWarpGemmCToA(p, p_cast);
+#else
+            const auto p = p_cast;
+#endif
 
             __builtin_amdgcn_sched_barrier(0x00000001);
 
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
index e5e9e2333a..bc54f75e06 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
@@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
     template
     CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
     {
+#if defined(__gfx11__)
+        // gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM K-per-thread;
+        // clamping to 8 under-reserves LDS padding for K-per-thread 16 variants.
+        return GetKVWarpGemmKPerThreadSize();
+#else
         if constexpr(GetKVWarpGemmKPerThreadSize() >= 8)
             return 8;
         else
             return 4;
+#endif
     }
 
     template
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
index 9fc3652f51..4eb5eb291a 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
@@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS
 
             block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{});
 
+#if defined(__gfx11__)
+            // gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
+            // softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
+            auto p = make_static_distributed_tensor(
+                decltype(gemm_1)::template MakeABlockTileDistribution());
+            PermuteWarpGemmCToA(
+                p, cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)));
+#else
             const auto p =
                 cast_tile(tile_elementwise_in(p_compute_element_func, p_compute));
+#endif
 
             __builtin_amdgcn_sched_barrier(0);
 