mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang <mengyang@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps;
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
@@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
total_blob = list()
|
||||
|
||||
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
|
||||
current_trait_dict = h_trait_dicts[model_sensitive_flag]
|
||||
for hs_key in current_trait_dict:
|
||||
hs = current_trait_dict[hs_key]
|
||||
hs = current_trait_dict[hs_key]
|
||||
current_n = hs_key
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
|
||||
Reference in New Issue
Block a user