[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)

* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm

* Update rmsnorm host reference

* Update tree reduction of rmsnorm for reference host

* Fix cross warp for m > 1 cases

* Add RMSNorm model selectable option for host reference

* Fix save_unquant cases

* Update reference rmsnorm forward function to use enum for model sensitivity

* Update reference rmsnorm calculation for model sensitivity

* Fix m warp for layernorm

* Adjust parameter of reference for twoPass

* Fix clang format

* Run clang-format-overwrite.sh to fix formating issue

* fix clang format

---------

Co-authored-by: MHYang <mengyang@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
ClementLinCF
2025-10-14 02:52:37 +08:00
committed by GitHub
parent fc2a121c44
commit e1b0bdfbfa
7 changed files with 217 additions and 67 deletions

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile {
@@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
HostTensor<InvRmsDataType>& invRms_m,
HostTensor<UnquantYDataType>& unquant_y_m_n,
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
Epilogue epilogue_functor = {},
const int use_model_sensitive_rmsnorm =
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
@@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
acc(m, n) = x * divisor * gamma;
if(use_model_sensitive_rmsnorm ==
static_cast<int>(
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
{
acc(m, n) = x * divisor * gamma;
}
else if(use_model_sensitive_rmsnorm ==
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
{
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
acc(m, n) = rmsn_;
}
else
{
const auto tmp = type_convert<XDataType>(x * divisor);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
acc(m, n) = rmsn_;
}
}
}
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
@@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile