Introduce tree reduction for BlockReduce2dCrossWarpSync (#2588)

* Introduce tree reduction for BlockReduce2dCrossWarpSync

* Rename original impl to BlockReduce2dLinearCrossWarpSync

* Replace warp_size with get_warp_size()

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
MHYangAMD
2025-10-22 14:41:35 +08:00
committed by GitHub
parent 37dff024c1
commit 5a27a97391
3 changed files with 120 additions and 150 deletions

View File

@@ -69,15 +69,6 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
return BlockReduce2dCrossWarpSync<P_>{};
}
// Returns a value-initialized BlockReduce2dTreeCrossWarpSync object.
// The returned type is parameterized on a BlockReduce2dProblem whose first two
// type arguments are both Problem::ComputeDataType and whose third type
// argument is Problem::BlockShape.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync()
{
// P_ bundles the compute data type (passed twice) and the block shape taken from Problem.
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dTreeCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{

View File

@@ -102,8 +102,8 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_tree_cross_warp_sync =
Policy::template GetBlockReduce2dTreeCrossWarpSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
@@ -162,7 +162,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
reduce_square_sum_func);
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(