mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix gfx11 FMHA P tile layout remaps
This commit is contained in:
@@ -1088,7 +1088,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
const auto p_cast = [&]() {
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
@@ -1423,6 +1423,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(p, p_cast);
|
||||
#else
|
||||
const auto p = p_cast;
|
||||
#endif
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
// KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale
|
||||
|
||||
@@ -762,7 +762,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
const auto p_cast = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
|
||||
@@ -777,6 +777,15 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(p, p_cast);
|
||||
#else
|
||||
const auto p = p_cast;
|
||||
#endif
|
||||
|
||||
float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
|
||||
@@ -685,7 +685,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p_cast =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(p, p_cast);
|
||||
#else
|
||||
const auto p = p_cast;
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
@@ -145,10 +145,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
// Returns the compile-time K-pack value used for V (presumably the LDS/shared-memory
// layout of the V tile, going by the "Smem" in the name - confirm against callers).
//
// Selection, as written below:
//   * __gfx11__ builds : GetKVWarpGemmKPerThreadSize<Problem>(), passed through unchanged.
//   * all other targets: 8 when GetKVWarpGemmKPerThreadSize<Problem>() >= 8, otherwise 4.
//
// NOTE(review): the gfx11 branch has no lower bound, while the generic branch never
// returns less than 4. If a gfx11 warp GEMM can report a K-per-thread below 4, this
// branch would return a smaller K-pack than before - confirm that cannot happen, or
// that it is intended.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA V loads expect the LDS K-pack to match the warp GEMM K-per-thread;
|
||||
// clamping to 8 under-reserves LDS padding for K-per-thread 16 variants.
|
||||
// Hence the value is forwarded as-is instead of going through the 8/4 selection below.
return GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
#else
|
||||
// Generic targets: quantize the K-pack to one of two values (8 or 4).
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -712,7 +712,17 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
auto p_cast =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(p, p_cast);
|
||||
#else
|
||||
const auto p = p_cast;
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
@@ -512,8 +512,17 @@ struct BlockFmhaPipelineQSKSVS
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 WMMA uses different lane layouts for GEMM C and GEMM A tiles, so remap
|
||||
// softmax P from GEMM0's C layout into GEMM1's A layout before the PV GEMM.
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(
|
||||
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
|
||||
#else
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user