Introduce tree reduction for BlockReduce2dCrossWarpSync (#2588)

* Introduce tree reduction for BlockReduce2dCrossWarpSync

* Rename original impl to BlockReduce2dLinearCrossWarpSync

* Replace warp_size with get_warp_size()

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Author: MHYangAMD
Date: 2025-10-22 14:41:35 +08:00
Committed by: GitHub
Parent: 37dff024c1
Commit: 5a27a97391
3 changed files with 120 additions and 150 deletions
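The substance of the change: the cross-warp combine of per-warp partials, previously a single-accumulator linear loop, now runs as a pairwise tree over the values read back from LDS. A minimal standalone sketch of the two schedules (plain C++; the function names are illustrative only, not part of ck_tile):

#include <cstddef>

// Linear schedule (old behaviour, kept as BlockReduce2dLinearCrossWarpSync):
// one accumulator and n - 1 dependent combines.
template <typename T, typename ReduceFunc>
T reduce_linear(const T* partials, std::size_t n, const ReduceFunc& reduce_func)
{
    T acc = partials[0];
    for(std::size_t i = 1; i < n; ++i)
        acc = reduce_func(acc, partials[i]);
    return acc;
}

// Tree schedule (new behaviour): log2(n) stages of independent pairwise
// combines, mirroring the strided static_for loop introduced in operator().
// The kernel additionally asserts that the warp count is a power of two.
template <typename T, typename ReduceFunc>
T reduce_tree(T* partials, std::size_t n, const ReduceFunc& reduce_func)
{
    for(std::size_t stride = 1; stride < n; stride *= 2)
        for(std::size_t i = 0; i + stride < n; i += 2 * stride)
            partials[i] = reduce_func(partials[i], partials[i + stride]);
    return partials[0];
}

In the kernel the same schedule runs per thread over the num_reduce_warps values staged in LDS, so the dependency chain shrinks from num_reduce_warps - 1 combines to log2(num_reduce_warps) stages.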


@@ -230,9 +230,121 @@ struct BlockReduce2dCrossWarpSync
template <typename YDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
// constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * thread_buf_size * sizeof(DataType);
}
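// Illustrative sizing only (assumed parameters, not taken from this commit):
// with BlockSize = 256 and get_warp_size() = 64, num_warps = 4; with
// thread_buf_size = 2 and DataType = float this returns 4 * 2 * sizeof(float)
// = 32 bytes, i.e. one LDS slot per warp for each buffered element.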
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
if constexpr(num_reduce_warps == 1)
return;
// Each warp's lane 0 writes its partial results to shared memory
const index_t smem_offset = warp_id;
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
// Store the i-th element of this warp's thread_buffer into SMEM
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
});
}
block_sync_lds();
// Each warp holds a duplicate of the per-warp partials and performs the reduction redundantly.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
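// Illustrative indexing only (assumed values, not taken from this commit):
// with num_warps = 4 and num_reduce_warps = 2, warps 0 and 1 both read LDS
// slots {0, 1} while warps 2 and 3 both read slots {2, 3}, so every warp in a
// reduce group redundantly computes the same reduced value.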
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v[num_reduce_warps];
static_for<0, num_reduce_warps, 1>{}(
[&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; });
static_assert(is_power_of_two_integer(num_reduce_warps),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(num_reduce_warps);
static_for<0, nstage, 1>{}([&](auto istage) {
constexpr index_t stride = 1 << istage.value;
static_for<0, num_reduce_warps, stride * 2>{}([&](auto idx_) {
constexpr index_t i0 = idx_();
constexpr index_t i1 = idx_ + stride;
if constexpr(i1 < num_reduce_warps)
{
v[i0] = reduce_func(v[i0], v[i1]);
}
});
});
y_tensor.get_thread_buffer()(i) = v[0];
});
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dLinearCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
template <typename YDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
{
constexpr index_t num_reduce_warps = [&]() {
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_warp = 0;
index_t len_ = 1;
static_for<0, NDimR, 1>{}([&](auto idim_r) {
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
len_ *= r_length;
}
});
return len_;
}();
return num_reduce_warps;
}
// return in byte
template <typename YDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
@@ -300,7 +412,9 @@ struct BlockReduce2dCrossWarpSync
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// reduce
v_local = reduce_func(v_local, v_remote);
});
y_tensor.get_thread_buffer()(i_0) = v_local;
@@ -308,139 +422,4 @@ struct BlockReduce2dCrossWarpSync
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dTreeCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
template <typename YDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
{
constexpr index_t num_reduce_warps = [&]() {
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_warp = 0;
index_t len_ = 1;
static_for<0, NDimR, 1>{}([&](auto idim_r) {
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
len_ *= r_length;
}
});
return len_;
}();
return num_reduce_warps;
}
// return in byte
template <typename YDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
if constexpr(num_reduce_warps == 1)
return;
// Each warp's lane 0 writes its partial results to shared memory
const index_t smem_offset = warp_id;
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
// Store the i-th element of this warp's thread_buffer into SMEM
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
});
}
block_sync_lds();
// Each warp holds a duplicate of the per-warp partials and performs the reduction redundantly.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v = 0;
if(lane_id < num_reduce_warps)
{
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
}
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// pull data from remote lane
const auto o =
__shfl_xor(v, number<lid_over_rid_derivative << istage.value>{}.value);
// reduce
v = reduce_func(v, o);
});
}
});
y_tensor.get_thread_buffer()(i) = v;
});
}
};
} // namespace ck_tile
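For reference, the cross-lane stage that the removed BlockReduce2dTreeCrossWarpSync performed with __shfl_xor is the classic butterfly exchange. A HIP-style sketch of that stage in isolation (the fixed warp size of 64 and the sum operation are assumptions, not taken from the file):

// Hypothetical warp-level sum over 64 lanes using the same butterfly pattern:
// at stage s, each lane exchanges values with the lane whose id differs in bit s.
__device__ float warp_butterfly_sum(float v)
{
    for(int mask = 1; mask < 64; mask <<= 1) // log2(64) = 6 stages
    {
        const float o = __shfl_xor(v, mask); // pull the partner lane's value
        v += o;                              // combine local and remote partials
    }
    return v; // every lane ends with the full warp sum
}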