mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm
* Update rmsnorm host reference
* Update tree reduction of rmsnorm for reference host
* Fix cross warp for m > 1 cases
* Add RMSNorm model selectable option for host reference
* Fix save_unquant cases
* Update reference rmsnorm forward function to use enum for model sensitivity
* Update reference rmsnorm calculation for model sensitivity
* Fix m warp for layernorm
* Adjust parameter of reference for twoPass
* Fix clang format
* Run clang-format-overwrite.sh to fix formatting issue
* fix clang format

---------

Co-authored-by: MHYang <mengyang@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
HostTensor<InvRmsDataType>& invRms_m,
|
||||
HostTensor<UnquantYDataType>& unquant_y_m_n,
|
||||
ComputeDataType epsilon,
|
||||
Epilogue epilogue_functor = {})
|
||||
Epilogue epilogue_functor = {},
|
||||
const int use_model_sensitive_rmsnorm =
|
||||
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
|
||||
{
|
||||
auto rmsnorm2d_fwd_func = [&](auto m) {
|
||||
const int N = x_m_n.mDesc.get_lengths()[1];
|
||||
@@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
{
|
||||
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
|
||||
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
if(use_model_sensitive_rmsnorm ==
|
||||
static_cast<int>(
|
||||
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
|
||||
{
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
}
|
||||
else if(use_model_sensitive_rmsnorm ==
|
||||
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
|
||||
{
|
||||
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
|
||||
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
|
||||
type_convert<ComputeDataType>(tmp0) * gamma);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
|
||||
acc(m, n) = rmsn_;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = type_convert<XDataType>(x * divisor);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
|
||||
acc(m, n) = rmsn_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
|
||||
@@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync
|
||||
block_sync_lds();
|
||||
|
||||
// We let each warp hold a duplicate to do the reduction.
|
||||
const index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
const index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v = 0;
|
||||
if(lane_id < num_reduce_warps)
|
||||
{
|
||||
v = smem_ptr[lane_id + i * num_warps];
|
||||
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
|
||||
@@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
|
||||
set_tile(square_sum, 0);
|
||||
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
|
||||
if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
|
||||
{
|
||||
sweep_tile(
|
||||
acc,
|
||||
@@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
|
||||
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 =
|
||||
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
|
||||
@@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user