[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)

* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm

* Update rmsnorm host reference

* Update tree reduction of rmsnorm for reference host

* Fix cross warp for m > 1 cases

* Add RMSNorm model selectable option for host reference

* Fix save_unquant cases

* Update reference rmsnorm forward function to use enum for model sensitivity

* Update reference rmsnorm calculation for model sensitivity

* Fix m warp for layernorm

* Adjust parameter of reference for twoPass

* Fix clang format

* Run clang-format-overwrite.sh to fix formating issue

* fix clang format

---------

Co-authored-by: MHYang <mengyang@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
ClementLinCF
2025-10-14 02:52:37 +08:00
committed by GitHub
parent fc2a121c44
commit e1b0bdfbfa
7 changed files with 217 additions and 67 deletions

View File

@@ -75,6 +75,39 @@ struct layernorm2d_fwd_traits_
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps;
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;