From 0aee5c2d1668de14d2bc88c42f46b58f989dc2f2 Mon Sep 17 00:00:00 2001 From: kylasa Date: Thu, 6 Feb 2025 14:07:38 -0800 Subject: [PATCH] Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845) * Support bf16/fb8/bf8 datatypes for ck_tile/gemm * remove commented out code. * Addressing code review comments and enabling universal_gemm for all the supported data types. * Merge conflict resolution. * Solve the memory pipeline compilation error. Merge with the new change of CShuffle * finish the feature, pass the tests * Fix the pipeline and add the benchmark script for other data types --------- Co-authored-by: ThomasNing [ROCm/composable_kernel commit: ab5d0278664d75db4dbec8c7ff864f43b22e69b9] --- example/ck_tile/03_gemm/gemm_basic.cpp | 38 ++- example/ck_tile/03_gemm/gemm_basic.hpp | 51 ++- example/ck_tile/03_gemm/run_gemm_example.inc | 29 +- .../ck_tile/03_gemm/script/benchmark_basic.sh | 3 +- .../03_gemm/script/benchmark_basic_bf16.sh | 0 .../03_gemm/script/benchmark_basic_bf8.sh | 0 .../03_gemm/script/benchmark_basic_fp8.sh | 14 + .../03_gemm/script/benchmark_mem_pipeline.sh | 6 +- .../script/benchmark_mem_pipeline_bf16.sh | 13 + .../script/benchmark_mem_pipeline_bf8.sh | 13 + .../script/benchmark_mem_pipeline_fp8.sh | 13 + .../03_gemm/script/smoke_test_basic.sh | 31 +- .../03_gemm/script/smoke_test_mem_pipeline.sh | 31 +- example/ck_tile/03_gemm/universal_gemm.cpp | 99 +++++- .../core/arch/generic_memory_space_atomic.hpp | 303 +++++++++++++++++- include/ck_tile/host.hpp | 2 +- include/ck_tile/host/check_err.hpp | 20 +- .../ck_tile/host/reference/reference_gemm.hpp | 5 +- include/ck_tile/ops/batched_transpose.hpp | 2 +- .../ops/epilogue/cshuffle_epilogue.hpp | 3 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 10 +- 21 files changed, 598 insertions(+), 88 deletions(-) create mode 100644 example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh create mode 100644 example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh create mode 100644 
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh create mode 100644 example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh create mode 100644 example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh create mode 100644 example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index b667886f84..2e04780eb0 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. @@ -25,7 +31,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; using CodegenGemmShape = ck_tile::TileGemmShape, @@ -99,12 +105,32 @@ int run_gemm_example(int argc, char* argv[]) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "C") { - return 
run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else { diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 3fdc4ac46c..5fa94f5f72 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -18,7 +18,7 @@ #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave @@ -43,6 +43,33 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. 
}; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -64,13 +91,23 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; -using Types = GemmBasicTypeConfig; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index d32ec57be5..028f8a44c3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, 
ck_tile::DeviceMem& c_m_n_dev_buf, @@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( + float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name + << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } -template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, + invoke_gemm(a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_dev_buf, M, @@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc, a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, 
max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", @@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic.sh index 6c6049ef8b..a1646da5bd 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic.sh @@ -2,7 +2,8 @@ EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" VALID=1 -for b_matrix_layout in "R" "C"; do + +for b_matrix_layout in "C"; do for m in "64" "512" "1024" "2048"; do for n in "512" "1024" "2048"; do for k in "64" "512" "1024" "2048"; do diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh new file mode 100644 index 0000000000..21462616be --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh index 8ff7d7ad44..c4cf4ddcbf 100755 --- a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh @@ -2,10 +2,10 @@ EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" VALID=1 -for b_matrix_layout in "R" "C"; do - for m in "64" "512" "1024" "2048"; do +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do for n in "512" "1024" "2048"; do - for k in "64" "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID done done diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh new file mode 100644 index 0000000000..903b4a3c0f --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh new file mode 100644 index 0000000000..8c92c2e991 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh new file mode 100644 index 0000000000..e238006c7d --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/smoke_test_basic.sh b/example/ck_tile/03_gemm/script/smoke_test_basic.sh index 8eb4e101a0..7ca6759f42 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_basic.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_basic.sh @@ -7,22 +7,20 @@ export CK_REPEAT=1 COMMON_ARGS='-v=2 -warmup=0 -repeat=1' -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 64 128; do - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." 
+ # Optionally, exit or break if you need to halt further execution + # exit 1 + fi - done done done done @@ -30,6 +28,9 @@ run_fp16_tests() { set -x -run_fp16_tests +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" set +x diff --git a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh index a9c7f48da0..951f8aa63a 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh @@ -7,22 +7,20 @@ export CK_REPEAT=1 COMMON_ARGS='-v=2 -warmup=0 -repeat=1' -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do +run_tests() { + for m in 512 1024; do + for n in 512 2048; do + for k in 512 1024; do - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." 
+ # Optionally, exit or break if you need to halt further execution + # exit 1 + fi - done done done done @@ -30,6 +28,9 @@ run_fp16_tests() { set -x -run_fp16_tests +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" set +x diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eaaf3dbed9..08a9cdb24b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) @@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -243,24 +249,101 @@ int run_gemm_example(int argc, char* argv[]) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + 
} + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, 
Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else { diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index 6212db9169..e6fc08c545 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core/numeric/vector_type.hpp" @@ -8,16 +8,75 @@ namespace ck_tile { -CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) +template +CK_TILE_HOST_DEVICE T add(const T& a, const T& b) { - return type_convert(type_convert(a) + type_convert(b)); + return type_convert(type_convert(a) + type_convert(b)); } CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) { bf16x2_t rtn; - rtn[0] = add_bf16_t(a[0], b[0]); - rtn[1] = add_bf16_t(a[1], b[1]); + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + return rtn; +} + +CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b) +{ + bf16x4_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + return rtn; +} + +CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b) +{ + fp8x4_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + return rtn; +} + +CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b) +{ + fp8x8_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + rtn[4] = add(a[4], b[4]); + rtn[5] = add(a[5], b[5]); + rtn[6] = add(a[6], b[6]); + rtn[7] = add(a[7], b[7]); + return rtn; +} + 
+CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b) +{ + bf8x4_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + return rtn; +} + +CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b) +{ + bf8x8_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + rtn[4] = add(a[4], b[4]); + rtn[5] = add(a[5], b[5]); + rtn[6] = add(a[6], b[6]); + rtn[7] = add(a[7], b[7]); return rtn; } @@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) } while(cur_v.u32 != old_v); } +template <> +CK_TILE_DEVICE void atomic_add(bf16x4_t* p_dst, bf16x4_t const& x) +{ + // Union to treat the pointer as either bf16x4_t* or uint64_t*: + union U64BF164_ADDR + { + uint64_t* u64_a; + bf16x4_t* bf164_a; + }; + + // Union to treat the data as either bf16x4_t or 64-bit integer + union U64BF164 + { + uint64_t u64; + bf16x4_t bf164; + }; + + U64BF164_ADDR addr; + addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location + + // First read (non-atomic) of the old value + U64BF164 cur_v; + cur_v.u64 = *addr.u64_a; + + U64BF164 new_v_union; + uint64_t old_v, new_v; + + do + { + // old 64 bits + old_v = cur_v.u64; + + // Add elementwise in bf16 + new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x); + new_v = new_v_union.u64; + + // Attempt the 64-bit CAS + cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v); + + } while(cur_v.u64 != old_v); +} + +template <> +CK_TILE_DEVICE void atomic_add(fp8x4_t* p_dst, const fp8x4_t& x) +{ + union U32FP84_ADDR + { + uint32_t* u32_a; + fp8x4_t* fp84_a; + }; + + union U32FP84 + { + uint32_t u32; + fp8x4_t fp84; + }; + + U32FP84_ADDR dword_addr; + U32FP84 cur_v; + U32FP84 new_; + uint32_t old_v, new_v; + + dword_addr.fp84_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.fp84 = add_fp8x4_t(cur_v.fp84, x); + new_v = new_.u32; + 
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +template <> +CK_TILE_DEVICE void atomic_add(bf8x4_t* p_dst, const bf8x4_t& x) +{ + union U32BF84_ADDR + { + uint32_t* u32_a; + bf8x4_t* bf84_a; + }; + + union U32BF84 + { + uint32_t u32; + bf8x4_t bf84; + }; + + U32BF84_ADDR dword_addr; + U32BF84 cur_v; + U32BF84 new_; + uint32_t old_v, new_v; + + dword_addr.bf84_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.bf84 = add_bf8x4_t(cur_v.bf84, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +// +// Atomic add for fp8x8_t +// +template <> +CK_TILE_DEVICE void atomic_add(fp8x8_t* p_dst, fp8x8_t const& x) +{ + // Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer. + union U64FP88_ADDR + { + uint64_t* u64_a; // pointer to 64-bit integer + fp8x8_t* fp88_a; // pointer to fp8x8_t + }; + + union U64FP88 + { + uint64_t u64; + fp8x8_t fp88; + }; + + U64FP88_ADDR dword_addr; + U64FP88 cur_v; + U64FP88 new_v_union; + uint64_t old_v, new_v; + + // Point to the destination as both fp8x8_t* and uint64_t*. + dword_addr.fp88_a = p_dst; + // Initial read of 64 bits from memory + cur_v.u64 = *dword_addr.u64_a; + + do + { + old_v = cur_v.u64; + // Add each fp8 element using your add_fp8x8_t(...) 
routine + new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x); + new_v = new_v_union.u64; + + // Attempt 64-bit CAS + cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v); + } while(cur_v.u64 != old_v); +} + +// +// Atomic add for bf8x8_t +// +template <> +CK_TILE_DEVICE void atomic_add(bf8x8_t* p_dst, bf8x8_t const& x) +{ + union U64BF88_ADDR + { + uint64_t* u64_a; + bf8x8_t* bf88_a; + }; + + union U64BF88 + { + uint64_t u64; + bf8x8_t bf88; + }; + + U64BF88_ADDR dword_addr; + U64BF88 cur_v; + U64BF88 new_v_union; + uint64_t old_v, new_v; + + dword_addr.bf88_a = p_dst; + // Read the original 64 bits + cur_v.u64 = *dword_addr.u64_a; + + do + { + old_v = cur_v.u64; + // Add each bf8 element using your add_bf8x8_t(...) routine + new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x); + new_v = new_v_union.u64; + + // 64-bit CAS loop + cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v); + } while(cur_v.u64 != old_v); +} + template CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) { @@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) (std::is_same::value && (N == 1)) || (std::is_same::value && (N == 1 || N == 2)) || (std::is_same::value && (N == 1 || N == 2)) || - (std::is_same::value && (N == 2 || N == 4)), - "wrong! 
not implemented"); + (std::is_same::value && (N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 4 || N == 8 || N == 16)), + "The granularity of the thread buffer is unsupported on the hardware!"); constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; @@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) } else if constexpr(N == 4) { - atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); - atomic_add(c_style_pointer_cast(p_dst) + 1, - x.template get_as()[I1]); + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + else if constexpr(N == 8) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, + x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 8) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 16) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 8) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 16) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); } } } diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index bb5d8bfa86..39a904717c 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,6 +20,7 @@ #include "ck_tile/host/reference/reference_batched_masking.hpp" #include 
"ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" +#include "ck_tile/host/reference/reference_batched_transpose.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" @@ -34,4 +35,3 @@ #include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" -#include "ck_tile/host/reference/reference_batched_transpose.hpp" diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 2eff11dd25..ea70563d58 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -22,13 +22,14 @@ template ::value, + static_assert(is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; @@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; @@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; @@ -74,13 +75,14 @@ template ::value, + static_assert(is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); @@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + 
static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; @@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; @@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } if(!res) { - std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index fc412e8831..da0de457d4 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A, int b_index = (std::is_same_v) ? col * strideB + k : k * strideB + col; - acc += static_cast(A[a_index]) * static_cast(B[b_index]); + acc += ck_tile::type_convert(A[a_index]) * + ck_tile::type_convert(B[b_index]); } int c_index = (std::is_same_v) ? row * strideC + col : col * strideC + row; - C[c_index] = acc; + C[c_index] = ck_tile::type_convert(acc); } } diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 8741e0a49b..ade2f18041 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4aba3d7ec1..155dbad6e3 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -77,6 +77,7 @@ struct CShuffleEpilogue * * @return The vector store size for C tensor. */ + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { constexpr index_t MaxVectorStoreSize = 16; @@ -142,7 +143,7 @@ struct CShuffleEpilogue TileDistributionEncodingPattern2D(), tile_distribution_pattern::thread_raked>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 4c65f51914..aa31d1fccf 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -159,7 +159,7 @@ struct GemmKernel CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { - if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && is_any_of::value) { if(kargs.k_batch != 1) @@ -240,7 +240,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + if(kargs.N % EpiloguePipeline::template GetVectorSizeC() != 0) { std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; return false; @@ -255,7 +255,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + if(kargs.M % EpiloguePipeline::template GetVectorSizeC() != 0) { std::cerr << "M is not a multiple of vector load size for C tensor!" 
<< std::endl; return false; @@ -321,7 +321,7 @@ struct GemmKernel c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), - number{}, + number()>{}, number<1>{}); } else @@ -519,7 +519,7 @@ struct GemmKernel { // Do not compile in case where we have unsupported // VectorSizeC & data type configuration. - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && is_any_of::value)) { RunGemm(