mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Enhance RMSNorm Accuracy: New Pipeline Pass for Selectable Implementation (#2409)
* Add Rmsnorm2dFwdPipelineModelSensitiveT5Pass * Update rmsnorm2d_fwd_pipeline_model_sensitive_pass 1. Add BlockReduce2dTreeCrossWarpSync * Add Rmsnorm2dFusedModelSensitiveEnum * Update patch 1. Reverse generate.py 2. Remove comment in generate.py 3. Update tree cross warp reduce * Refactor RMSNorm model enum and introduce T5-like option * Update the n stage for cross warp reduce * Add new cmdline option in RMSNorm for new pipeline testing --------- Co-authored-by: Clement Lin <clement.lin@amd.com> Co-authored-by: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com>
This commit is contained in:
@@ -272,4 +272,137 @@ struct BlockReduce2dCrossWarpSync
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dTreeCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
|
||||
template <typename YDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
{
|
||||
constexpr index_t num_reduce_warps = [&]() {
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_warp = 0;
|
||||
|
||||
index_t len_ = 1;
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
len_ *= r_length;
|
||||
}
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
return num_reduce_warps;
|
||||
}
|
||||
|
||||
// return in byte
|
||||
template <typename YDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// we need to store all data from every wave into smem
|
||||
// e.g. 2x2 reduce along N
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | ___> | w01 |
|
||||
// | w2 | w3 | | w23 |
|
||||
//
|
||||
// -> store data from every wave into LDS
|
||||
//
|
||||
//
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
// Each warp's lane 0 writes its partial results to shared memory
|
||||
const index_t smem_offset = warp_id;
|
||||
if(lane_id == 0)
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
// Store the i-th element of this warp's thread_buffer into SMEM
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// We let each warp holds a duplication to do reduction.
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v = 0;
|
||||
if(lane_id < num_reduce_warps)
|
||||
{
|
||||
v = smem_ptr[lane_id + i * num_warps];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// pull data from remote lane
|
||||
const auto o =
|
||||
__shfl_xor(v, number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// reduce
|
||||
v = reduce_func(v, o);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i) = v;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user