mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang <mengyang@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -1,49 +1,85 @@
|
||||
#!/bin/sh
|
||||
#!/bin/bash
|
||||
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
|
||||
done
|
||||
done
|
||||
done
|
||||
total=0
|
||||
valid=0
|
||||
|
||||
run_case() {
|
||||
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
|
||||
echo "[CMD] $cmd"
|
||||
output=$($cmd 2>&1)
|
||||
echo "$output"
|
||||
if echo "$output" | grep -q "valid:y"; then
|
||||
valid=$((valid + 1))
|
||||
fi
|
||||
total=$((total + 1))
|
||||
}
|
||||
|
||||
fquant_list=(
|
||||
""
|
||||
"-fquant=1 -prec_o=int8"
|
||||
"-fquant=2 -prec_o=int8"
|
||||
"-fquant=1 -prec_o=fp8"
|
||||
"-fquant=2 -prec_o=fp8"
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1"
|
||||
"-fquant=2 -prec_o=int8 -save_unquant=1"
|
||||
"-fquant=1 -prec_o=fp8 -save_unquant=1"
|
||||
"-fquant=2 -prec_o=fp8 -save_unquant=1"
|
||||
)
|
||||
|
||||
m_n_list=(
|
||||
"99 13" "17 16" "1 100" "4 128" "80 127"
|
||||
"7 599" "19 512" "11 510" "91 636"
|
||||
"31 1024" "8 1501" "3 1826" "5 2040"
|
||||
"7 2734" "1 3182" "9 4096" "3 8192"
|
||||
)
|
||||
|
||||
### Add special stride test ###
|
||||
m_n_stride_list=(
|
||||
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
|
||||
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
|
||||
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
|
||||
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
|
||||
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
|
||||
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
|
||||
)
|
||||
|
||||
for fquant in "${fquant_list[@]}"; do
|
||||
for pr_i in "fp16" "bf16"; do
|
||||
for fadd in "0" "1"; do
|
||||
for s in "0" "1"; do
|
||||
for pair in "${m_n_list[@]}"; do
|
||||
m=$(echo $pair | cut -d ' ' -f1)
|
||||
n=$(echo $pair | cut -d ' ' -f2)
|
||||
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
|
||||
done
|
||||
|
||||
### Running tests with stride ###
|
||||
for triple in "${m_n_stride_list[@]}"; do
|
||||
m=$(echo $triple | cut -d ' ' -f1)
|
||||
n=$(echo $triple | cut -d ' ' -f2)
|
||||
stride_args=$(echo $triple | cut -d ' ' -f3-)
|
||||
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# The following cases uses two pass pipeline which doesn't support quant epilogue.
|
||||
for fquant in ""
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
# Special two-pass only
|
||||
for pr_i in "fp16" "bf16"; do
|
||||
for fadd in "0" "1"; do
|
||||
for s in "0" "1"; do
|
||||
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Summary
|
||||
echo "=============================="
|
||||
echo "Total cases: $total"
|
||||
echo "Valid cases: $valid"
|
||||
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
|
||||
echo "Accuracy: $accuracy%"
|
||||
echo "=============================="
|
||||
|
||||
Reference in New Issue
Block a user