Fix gfx11 FMHA P tile layout remaps

This commit is contained in:
Aaryaman Vasishta
2026-05-06 01:03:01 +09:00
parent 198320a7c8
commit 63b5f8e2e8
6 changed files with 57 additions and 4 deletions

View File

@@ -1088,7 +1088,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
}
const auto p = [&]() {
const auto p_cast = [&]() {
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
// STAGE 2, scale_s, add bias, mask, softmax
@@ -1423,6 +1423,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
tile_elementwise_in(p_compute_element_func, p_compute));
#endif
}();
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(p, p_cast);
#else
const auto p = p_cast;
#endif
// STAGE 3, KV gemm
// KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale

View File

@@ -762,7 +762,7 @@ struct BlockFmhaPipelineQRKSVSAsync
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
const auto p = [&]() {
const auto p_cast = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
@@ -777,6 +777,15 @@ struct BlockFmhaPipelineQRKSVSAsync
tile_elementwise_in(p_compute_element_func, p_compute));
#endif
}();
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(p, p_cast);
#else
const auto p = p_cast;
#endif
float v_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)

View File

@@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p_cast =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(p, p_cast);
#else
const auto p = p_cast;
#endif
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
    // K-pack granularity used when staging the V tile in LDS.
#if defined(__gfx11__)
    // gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM
    // K-per-thread exactly; clamping it would under-reserve the LDS padding
    // for the K-per-thread 16 variants.
    return GetKVWarpGemmKPerThreadSize<Problem>();
#else
    // Other targets: cap the pack size at 8, falling back to 4 otherwise.
    constexpr auto k_per_thread = GetKVWarpGemmKPerThreadSize<Problem>();
    return k_per_thread >= 8 ? 8 : 4;
#endif
}
template <typename Problem>

View File

@@ -712,7 +712,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
__builtin_amdgcn_sched_barrier(0x00000001);
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
auto p_cast =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(p, p_cast);
#else
const auto p = p_cast;
#endif
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#if defined(__gfx11__)
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
#else
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#endif
__builtin_amdgcn_sched_barrier(0);