mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang <mengyang@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
HostTensor<InvRmsDataType>& invRms_m,
|
||||
HostTensor<UnquantYDataType>& unquant_y_m_n,
|
||||
ComputeDataType epsilon,
|
||||
Epilogue epilogue_functor = {})
|
||||
Epilogue epilogue_functor = {},
|
||||
const int use_model_sensitive_rmsnorm =
|
||||
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
|
||||
{
|
||||
auto rmsnorm2d_fwd_func = [&](auto m) {
|
||||
const int N = x_m_n.mDesc.get_lengths()[1];
|
||||
@@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
{
|
||||
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
|
||||
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
if(use_model_sensitive_rmsnorm ==
|
||||
static_cast<int>(
|
||||
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
|
||||
{
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
}
|
||||
else if(use_model_sensitive_rmsnorm ==
|
||||
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
|
||||
{
|
||||
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
|
||||
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
|
||||
type_convert<ComputeDataType>(tmp0) * gamma);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
|
||||
acc(m, n) = rmsn_;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = type_convert<XDataType>(x * divisor);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
|
||||
acc(m, n) = rmsn_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
|
||||
@@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user