[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)

* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm

* Update rmsnorm host reference

* Update tree reduction of rmsnorm for reference host

* Fix cross warp for m > 1 cases

* Add RMSNorm model selectable option for host reference

* Fix save_unquant cases

* Update reference rmsnorm forward function to use enum for model sensitivity

* Update reference rmsnorm calculation for model sensitivity

* Fix m warp for layernorm

* Adjust parameter of reference for twoPass

* Fix clang format

* Run clang-format-overwrite.sh to fix formating issue

* fix clang format

---------

Co-authored-by: MHYang <mengyang@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
ClementLinCF
2025-10-14 02:52:37 +08:00
committed by GitHub
parent fc2a121c44
commit e1b0bdfbfa
7 changed files with 217 additions and 67 deletions

View File

@@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps;
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
@@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
}
}
}
total_blob = list()
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
current_trait_dict = h_trait_dicts[model_sensitive_flag]
for hs_key in current_trait_dict:
hs = current_trait_dict[hs_key]
hs = current_trait_dict[hs_key]
current_n = hs_key
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
prec_i, prec_o = dtype.split(',')

View File

@@ -70,16 +70,16 @@ template <typename InDataType,
bool SaveUnquant>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
@@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
return base_str;
}();
if(n > 8192)
{
use_model_sensitive_rmsnorm = 0;
}
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
@@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const int N = acc_.mDesc.get_lengths()[1];
for(int n_ = 0; n_ < N; ++n_)
{
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
}
dquant_functor(m_, o_, acc_);
@@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
default_and_dquant_functor);
default_and_dquant_functor,
use_model_sensitive_rmsnorm);
}
else
{
@@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
dquant_functor);
dquant_functor,
use_model_sensitive_rmsnorm);
}
}
else
@@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
YDataType,
InvRmsDataType,
ck_tile::null_type>(
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
x_host,
gamma_host,
y_host_ref,
invRms_host_ref,
unquant_y_null,
epsilon,
ck_tile::reference_rmsnorm2d_default_epilogue{},
use_model_sensitive_rmsnorm);
}
y_buf.FromDevice(y_host_dev.data());
@@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
if constexpr(SaveUnquant)
{
unquant_y_buf.FromDevice(unquant_y_host_dev.data());
}
auto [rtol, atol] = get_elimit<YDataType>();
if(x_stride == n)
{

View File

@@ -1,49 +1,85 @@
#!/bin/sh
#!/bin/bash
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
done
done
done
total=0
valid=0
run_case() {
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
echo "[CMD] $cmd"
output=$($cmd 2>&1)
echo "$output"
if echo "$output" | grep -q "valid:y"; then
valid=$((valid + 1))
fi
total=$((total + 1))
}
fquant_list=(
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
"-fquant=1 -prec_o=int8 -save_unquant=1"
"-fquant=2 -prec_o=int8 -save_unquant=1"
"-fquant=1 -prec_o=fp8 -save_unquant=1"
"-fquant=2 -prec_o=fp8 -save_unquant=1"
)
m_n_list=(
"99 13" "17 16" "1 100" "4 128" "80 127"
"7 599" "19 512" "11 510" "91 636"
"31 1024" "8 1501" "3 1826" "5 2040"
"7 2734" "1 3182" "9 4096" "3 8192"
)
### Add special stride test ###
m_n_stride_list=(
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
)
for fquant in "${fquant_list[@]}"; do
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
for pair in "${m_n_list[@]}"; do
m=$(echo $pair | cut -d ' ' -f1)
n=$(echo $pair | cut -d ' ' -f2)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
done
### Running tests with stride ###
for triple in "${m_n_stride_list[@]}"; do
m=$(echo $triple | cut -d ' ' -f1)
n=$(echo $triple | cut -d ' ' -f2)
stride_args=$(echo $triple | cut -d ' ' -f3-)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
done
done
done
done
done
# The following cases uses two pass pipeline which doesn't support quant epilogue.
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
# Special two-pass only
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
done
done
done
# Summary
echo "=============================="
echo "Total cases: $total"
echo "Valid cases: $valid"
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
echo "Accuracy: $accuracy%"
echo "=============================="