[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)

* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm

* Update rmsnorm host reference

* Update tree reduction of rmsnorm for reference host

* Fix cross warp for m > 1 cases

* Add RMSNorm model selectable option for host reference

* Fix save_unquant cases

* Update reference rmsnorm forward function to use enum for model sensitivity

* Update reference rmsnorm calculation for model sensitivity

* Fix m warp for layernorm

* Adjust parameter of reference for twoPass

* Fix clang format

* Run clang-format-overwrite.sh to fix formatting issue

* fix clang format

---------

Co-authored-by: MHYang <mengyang@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
ClementLinCF
2025-10-14 02:52:37 +08:00
committed by GitHub
parent fc2a121c44
commit e1b0bdfbfa
7 changed files with 217 additions and 67 deletions

View File

@@ -1,49 +1,85 @@
#!/bin/sh
#!/bin/bash
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
done
done
done
# Counters shared by every run_case call and read by the final summary.
total=0
valid=0

#######################################
# Run one invocation of $EXE and record whether it reported "valid:y".
# Globals:
#   EXE   (read)    - path of the executable under test
#   total (written) - incremented once per call
#   valid (written) - incremented when the captured output contains "valid:y"
# Arguments:
#   $1 - value passed as -prec_i
#   $2 - value passed as -fadd
#   $3 - value passed as -s
#   $4 - extra quant options, whitespace separated (may be empty)
#   $5 - value passed as -m
#   $6 - value passed as -n
#   $7 - extra stride options, whitespace separated (may be empty)
# Outputs:
#   Prints the command line, then the combined stdout/stderr of the run.
#######################################
run_case() {
    local -a quant_args=() stride_args=() cmd=()
    local output

    # $4 and $7 each pack several options into one string. Split them on
    # whitespace only: unlike an unquoted expansion this never glob-expands.
    read -r -a quant_args <<< "$4"
    read -r -a stride_args <<< "$7"

    # An argv array keeps $EXE intact even when its path contains spaces.
    cmd=("$EXE" "-prec_i=$1" "-fadd=$2" "-s=$3" "${quant_args[@]}"
         "-m=$5" "-n=$6" "${stride_args[@]}")

    echo "[CMD] $EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
    output=$("${cmd[@]}" 2>&1)
    printf '%s\n' "$output"

    if [[ "$output" == *"valid:y"* ]]; then
        valid=$((valid + 1))
    fi
    total=$((total + 1))
}
# Quant option sets: first the empty (no quant) entry, then every
# -fquant / -prec_o combination, without and then with -save_unquant=1.
fquant_list=("")
for suffix in "" " -save_unquant=1"; do
    for out_prec in int8 fp8; do
        for mode in 1 2; do
            fquant_list+=("-fquant=$mode -prec_o=$out_prec$suffix")
        done
    done
done

# "m n" shapes run without explicit strides.
m_n_list=(
    "99 13"
    "17 16"
    "1 100"
    "4 128"
    "80 127"
    "7 599"
    "19 512"
    "11 510"
    "91 636"
    "31 1024"
    "8 1501"
    "3 1826"
    "5 2040"
    "7 2734"
    "1 3182"
    "9 4096"
    "3 8192"
)

# Strided shapes: each "m n stride" spec becomes "m n <options>", with the
# same stride value applied to the x, xr, y and yr stride options.
m_n_stride_list=()
for spec in "22 255 256" "33 313 1000" "171 676 818" \
            "12 768 800" "100 766 812" "64 1000 1004"; do
    read -r rows cols pitch <<< "$spec"
    m_n_stride_list+=("$rows $cols -x_stride=$pitch -xr_stride=$pitch -y_stride=$pitch -yr_stride=$pitch")
done

# Drop the scratch variables used to build the lists.
unset suffix out_prec mode spec rows cols pitch
# Sweep every quant option set x input precision x -fadd flag x -s value
# over all shapes, recording each result through run_case.
for fquant in "${fquant_list[@]}"; do
    for pr_i in "fp16" "bf16"; do
        for fadd in "0" "1"; do
            for s in "0" "1"; do
                # Plain "m n" shapes run first, then the "m n <stride options>"
                # entries. read peels off m and n and leaves the remainder
                # (empty for plain shapes) in stride_args, so no echo/cut
                # subprocesses are forked per field.
                for shape in "${m_n_list[@]}" "${m_n_stride_list[@]}"; do
                    read -r m n stride_args <<< "$shape"
                    run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
                done
            done
        done
    done
done
# The following cases use the two-pass pipeline, which doesn't support the quant epilogue.
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
# Two-pass-only shape (m=1, n=10547), run with no quant or stride options.
# The brace expansion yields "prec fadd s" triples in the same order as three
# nested loops would: precision outermost, -s innermost.
for combo in {fp16,bf16}" "{0,1}" "{0,1}; do
    read -r pr_i fadd s <<< "$combo"
    run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
done
# Summary: report how many cases ran and what share printed "valid:y".
echo "=============================="
echo "Total cases: $total"
echo "Valid cases: $valid"
if (( ${total:-0} > 0 )); then
    # Pass the counters with -v instead of splicing them into the awk program.
    accuracy=$(awk -v ok="${valid:-0}" -v all="$total" \
        'BEGIN { printf "%.2f", (ok / all) * 100 }')
else
    # Nothing ran: avoid dividing by zero and report 0% explicitly.
    accuracy="0.00"
fi
echo "Accuracy: $accuracy%"
echo "=============================="