Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)

* Support bf16/fb8/bf8 datatypes for ck_tile/gemm

* remove commented out code.

* Addressing code review comments and enabling universal_gemm for all the supported data types.

* Merge conflict resolution.

* Solve the memory pipeline compilation error. Merge with the new change of CShuffle

* finish the feature, pass the tests

* Fix the pipeline and add the benchmark script for other data types

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
kylasa
2025-02-06 14:07:38 -08:00
committed by Sam Wu
parent 9b5dfba242
commit ab5d027866
21 changed files with 598 additions and 88 deletions

View File

@@ -2,7 +2,8 @@
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1
for b_matrix_layout in "R" "C"; do
for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do

View File

@@ -0,0 +1,14 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -2,10 +2,10 @@
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "R" "C"; do
for m in "64" "512" "1024" "2048"; do
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done

View File

@@ -0,0 +1,13 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -0,0 +1,13 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -0,0 +1,13 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
run_tests() {
for m in 128 1024; do
for n in 128 2048; do
for k in 64 128; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
run_fp16_tests
run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
set +x

View File

@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
run_tests() {
for m in 512 1024; do
for n in 512 2048; do
for k in 512 1024; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
run_fp16_tests
run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
set +x