mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang <mengyang@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps;
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
@@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
total_blob = list()
|
||||
|
||||
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
|
||||
current_trait_dict = h_trait_dicts[model_sensitive_flag]
|
||||
for hs_key in current_trait_dict:
|
||||
hs = current_trait_dict[hs_key]
|
||||
hs = current_trait_dict[hs_key]
|
||||
current_n = hs_key
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
|
||||
@@ -70,16 +70,16 @@ template <typename InDataType,
|
||||
bool SaveUnquant>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
@@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
if(n > 8192)
|
||||
{
|
||||
use_model_sensitive_rmsnorm = 0;
|
||||
}
|
||||
|
||||
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
|
||||
@@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const int N = acc_.mDesc.get_lengths()[1];
|
||||
for(int n_ = 0; n_ < N; ++n_)
|
||||
{
|
||||
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
|
||||
o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
|
||||
}
|
||||
|
||||
dquant_functor(m_, o_, acc_);
|
||||
@@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
invRms_host_ref,
|
||||
unquant_y_host_ref,
|
||||
epsilon,
|
||||
default_and_dquant_functor);
|
||||
default_and_dquant_functor,
|
||||
use_model_sensitive_rmsnorm);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
invRms_host_ref,
|
||||
unquant_y_host_ref,
|
||||
epsilon,
|
||||
dquant_functor);
|
||||
dquant_functor,
|
||||
use_model_sensitive_rmsnorm);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
ck_tile::null_type>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
|
||||
x_host,
|
||||
gamma_host,
|
||||
y_host_ref,
|
||||
invRms_host_ref,
|
||||
unquant_y_null,
|
||||
epsilon,
|
||||
ck_tile::reference_rmsnorm2d_default_epilogue{},
|
||||
use_model_sensitive_rmsnorm);
|
||||
}
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
@@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
}
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
unquant_y_buf.FromDevice(unquant_y_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<YDataType>();
|
||||
if(x_stride == n)
|
||||
{
|
||||
|
||||
@@ -1,49 +1,85 @@
|
||||
#!/bin/sh
|
||||
#!/bin/bash
|
||||
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
|
||||
done
|
||||
done
|
||||
done
|
||||
total=0
|
||||
valid=0
|
||||
|
||||
run_case() {
|
||||
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
|
||||
echo "[CMD] $cmd"
|
||||
output=$($cmd 2>&1)
|
||||
echo "$output"
|
||||
if echo "$output" | grep -q "valid:y"; then
|
||||
valid=$((valid + 1))
|
||||
fi
|
||||
total=$((total + 1))
|
||||
}
|
||||
|
||||
fquant_list=(
|
||||
""
|
||||
"-fquant=1 -prec_o=int8"
|
||||
"-fquant=2 -prec_o=int8"
|
||||
"-fquant=1 -prec_o=fp8"
|
||||
"-fquant=2 -prec_o=fp8"
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1"
|
||||
"-fquant=2 -prec_o=int8 -save_unquant=1"
|
||||
"-fquant=1 -prec_o=fp8 -save_unquant=1"
|
||||
"-fquant=2 -prec_o=fp8 -save_unquant=1"
|
||||
)
|
||||
|
||||
m_n_list=(
|
||||
"99 13" "17 16" "1 100" "4 128" "80 127"
|
||||
"7 599" "19 512" "11 510" "91 636"
|
||||
"31 1024" "8 1501" "3 1826" "5 2040"
|
||||
"7 2734" "1 3182" "9 4096" "3 8192"
|
||||
)
|
||||
|
||||
### Add special stride test ###
|
||||
m_n_stride_list=(
|
||||
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
|
||||
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
|
||||
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
|
||||
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
|
||||
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
|
||||
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
|
||||
)
|
||||
|
||||
for fquant in "${fquant_list[@]}"; do
|
||||
for pr_i in "fp16" "bf16"; do
|
||||
for fadd in "0" "1"; do
|
||||
for s in "0" "1"; do
|
||||
for pair in "${m_n_list[@]}"; do
|
||||
m=$(echo $pair | cut -d ' ' -f1)
|
||||
n=$(echo $pair | cut -d ' ' -f2)
|
||||
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
|
||||
done
|
||||
|
||||
### Running tests with stride ###
|
||||
for triple in "${m_n_stride_list[@]}"; do
|
||||
m=$(echo $triple | cut -d ' ' -f1)
|
||||
n=$(echo $triple | cut -d ' ' -f2)
|
||||
stride_args=$(echo $triple | cut -d ' ' -f3-)
|
||||
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# The following cases uses two pass pipeline which doesn't support quant epilogue.
|
||||
for fquant in ""
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
# Special two-pass only
|
||||
for pr_i in "fp16" "bf16"; do
|
||||
for fadd in "0" "1"; do
|
||||
for s in "0" "1"; do
|
||||
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Summary
|
||||
echo "=============================="
|
||||
echo "Total cases: $total"
|
||||
echo "Valid cases: $valid"
|
||||
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
|
||||
echo "Accuracy: $accuracy%"
|
||||
echo "=============================="
|
||||
|
||||
Reference in New Issue
Block a user