mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
[rocm-libraries] ROCm/rocm-libraries#7016 (commit 2b73c00)
[CK] Fix RDNA3 FMHA tile-load paths ## Summary Fix CK tile FMHA paths needed for RDNA3/RDNA4 targets. ## Details This PR addresses RDNA-specific issues hit while enabling xFormers CK FMHA on gfx11/gfx12: - On RDNA3, update FMHA P tile handling so the layout consumed by the second GEMM matches the WMMA path. ## Testing Validated downstream with xFormers CK/FMHA on gfx1201/gfx1151. ```text pytest --import-mode=importlib -q \ tests/test_mem_eff_attention.py::test_forward \ tests/test_mem_eff_attention.py::test_backward \ tests/test_mem_eff_attention.py::test_dropout_ck 3844 passed, 5244 skipped, 26 warnings
This commit is contained in:
committed by
assistant-librarian[bot]
parent
424dfec6e4
commit
457f153b69
@@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p_cast =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(p, p_cast);
|
||||
#else
|
||||
const auto p = p_cast;
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
@@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM K-per-thread;
|
||||
// clamping to 8 under-reserves LDS padding for K-per-thread 16 variants.
|
||||
return GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
#else
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(
|
||||
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
|
||||
#else
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user