mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Introduce tree reduction for BlockReduce2dCrossWarpSync (#2588)
* Introduce tree reduction for BlockReduce2dCrossWarpSync * Rename original impl to BlockReduce2dLinearCrossWarpSync * Replace warp_size with get_warp_size() --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -69,15 +69,6 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dTreeCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -102,8 +102,8 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_tree_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dTreeCrossWarpSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
@@ -162,7 +162,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
reduce_square_sum_func);
|
||||
}
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
|
||||
Reference in New Issue
Block a user