Merge commit 'e1b0bdfbfa92f47006fdbced627c7470eacdea2b' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-13 19:10:56 +00:00
parent 713691609c
commit 63d907604b
7 changed files with 217 additions and 67 deletions

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile {
@@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
HostTensor<InvRmsDataType>& invRms_m,
HostTensor<UnquantYDataType>& unquant_y_m_n,
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
Epilogue epilogue_functor = {},
const int use_model_sensitive_rmsnorm =
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
@@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
acc(m, n) = x * divisor * gamma;
if(use_model_sensitive_rmsnorm ==
static_cast<int>(
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
{
acc(m, n) = x * divisor * gamma;
}
else if(use_model_sensitive_rmsnorm ==
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
{
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
acc(m, n) = rmsn_;
}
else
{
const auto tmp = type_convert<XDataType>(x * divisor);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
acc(m, n) = rmsn_;
}
}
}
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
@@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile