diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 4161c2d5a4..b37b8cc27f 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -14,6 +14,7 @@ trigger: branches: include: - develop + - amd-develop paths: exclude: - .github diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 7ba7d4768b..5b00b5a123 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -13,7 +13,7 @@ #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/sequence.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" @@ -315,40 +315,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) std::cout << "Computing GEMM on host..." << std::endl; } - Tensor c({M, N}); - Tensor a({M, K}); - Tensor b({K, N}); - - for(int m = 0; m < M; m++) - { - for(int k = 0; k < K; k++) - { - a(m, k) = ck::type_convert(a_m_k(m, k)) * - ck::type_convert(a_m_k_scale(m, k / Scale_Block_K)); - } - } - - for(int n = 0; n < N; n++) - { - for(int k = 0; k < K; k++) - { - b(k, n) = ck::type_convert(b_k_n(k, n)) * - ck::type_convert(b_k_n_scale(k / Scale_Block_K, n)); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = - ref_gemm.MakeArgument(a, b, c, PassThrough{}, PassThrough{}, PassThrough{}); + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); ref_invoker.Run(ref_argument); @@ -366,8 +353,9 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; } - res_verified = res_verified && - ck::utils::check_err(c_m_n_device_result, c, "Error: Incorrect results!"); + res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!"); if(config.verbosity > 0 && res_verified) std::cout << "Done." << std::endl; diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index b667886f84..2e04780eb0 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. @@ -25,7 +31,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; using CodegenGemmShape = ck_tile::TileGemmShape, @@ -99,12 +105,32 @@ int run_gemm_example(int argc, char* argv[]) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else { diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 3fdc4ac46c..5fa94f5f72 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -18,7 +18,7 @@ #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE #endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave @@ -43,6 +43,33 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -64,13 +91,23 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; -using Types = GemmBasicTypeConfig; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index d32ec57be5..028f8a44c3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( + float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name + << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } -template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, + invoke_gemm(a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_dev_buf, M, @@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc, a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", @@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol + (K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic.sh index 6c6049ef8b..a1646da5bd 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic.sh @@ -2,7 +2,8 @@ EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" VALID=1 -for b_matrix_layout in "R" "C"; do + +for b_matrix_layout in "C"; do for m in "64" "512" "1024" "2048"; do for n in "512" "1024" "2048"; do for k in "64" "512" "1024" "2048"; do diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh new file mode 100644 index 0000000000..21462616be --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh index 8ff7d7ad44..c4cf4ddcbf 100755 --- a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh @@ -2,10 +2,10 @@ EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" VALID=1 -for b_matrix_layout in "R" "C"; do - for m in "64" "512" "1024" "2048"; do +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do for n in "512" "1024" "2048"; do - for k in "64" "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID done done diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh new file mode 100644 index 0000000000..903b4a3c0f --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh new file mode 100644 index 0000000000..8c92c2e991 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh new file mode 100644 index 0000000000..e238006c7d --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/smoke_test_basic.sh b/example/ck_tile/03_gemm/script/smoke_test_basic.sh index 8eb4e101a0..7ca6759f42 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_basic.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_basic.sh @@ -7,22 +7,20 @@ export CK_REPEAT=1 COMMON_ARGS='-v=2 -warmup=0 -repeat=1' -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 64 128; do - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi - done done done done @@ -30,6 +28,9 @@ run_fp16_tests() { set -x -run_fp16_tests +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" set +x diff --git a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh index a9c7f48da0..951f8aa63a 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh @@ -7,22 +7,20 @@ export CK_REPEAT=1 COMMON_ARGS='-v=2 -warmup=0 -repeat=1' -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do +run_tests() { + for m in 512 1024; do + for n in 512 2048; do + for k in 512 1024; do - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi - done done done done @@ -30,6 +28,9 @@ run_fp16_tests() { set -x -run_fp16_tests +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" set +x diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eaaf3dbed9..08a9cdb24b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) @@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -243,24 +249,101 @@ int run_gemm_example(int argc, char* argv[]) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 6d5aeed84b..a8cf681995 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) { arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); } if(!GridwiseGemm::CheckValidity(arg)) @@ -745,7 +746,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +CK_TILE_HOST_DEVICE T add(const T& a, const T& b) { - return type_convert(type_convert(a) + type_convert(b)); + return type_convert(type_convert(a) + type_convert(b)); } CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) { bf16x2_t rtn; - rtn[0] = add_bf16_t(a[0], b[0]); - rtn[1] = add_bf16_t(a[1], b[1]); + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + return rtn; +} + +CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b) +{ + bf16x4_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + return rtn; +} + +CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b) +{ + fp8x4_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + return rtn; +} + +CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b) +{ + fp8x8_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + rtn[4] = add(a[4], b[4]); + rtn[5] = add(a[5], b[5]); + rtn[6] = add(a[6], b[6]); + rtn[7] = add(a[7], b[7]); + return rtn; +} + +CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b) +{ + bf8x4_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + return rtn; +} + +CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b) +{ + bf8x8_t rtn; + rtn[0] = add(a[0], b[0]); + rtn[1] = add(a[1], b[1]); + rtn[2] = add(a[2], b[2]); + rtn[3] = add(a[3], b[3]); + rtn[4] = add(a[4], b[4]); + rtn[5] = add(a[5], b[5]); + rtn[6] = add(a[6], b[6]); + rtn[7] = add(a[7], b[7]); return rtn; } @@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) } while(cur_v.u32 != old_v); } +template <> +CK_TILE_DEVICE void atomic_add(bf16x4_t* p_dst, bf16x4_t const& x) +{ + // Union to treat the pointer as either bf16x4_t* or uint64_t*: + union U64BF164_ADDR + { + uint64_t* u64_a; + bf16x4_t* bf164_a; + }; + + // Union to treat the data as either bf16x4_t or 64-bit integer + union U64BF164 + { + uint64_t u64; + bf16x4_t bf164; + }; + + U64BF164_ADDR addr; + addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location + + // First read (non-atomic) of the old value + U64BF164 cur_v; + cur_v.u64 = *addr.u64_a; + + U64BF164 new_v_union; + uint64_t old_v, new_v; + + do + { + // old 64 bits + old_v = cur_v.u64; + + // Add elementwise in bf16 + new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x); + new_v = new_v_union.u64; + + // Attempt the 64-bit CAS + cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v); + + } while(cur_v.u64 != old_v); +} + +template <> +CK_TILE_DEVICE void atomic_add(fp8x4_t* p_dst, const fp8x4_t& x) +{ + union U32FP84_ADDR + { + uint32_t* u32_a; + fp8x4_t* fp84_a; + }; + + union U32FP84 + { + uint32_t u32; + fp8x4_t fp84; + }; + + U32FP84_ADDR dword_addr; + U32FP84 cur_v; + U32FP84 new_; + uint32_t old_v, new_v; + + dword_addr.fp84_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.fp84 = add_fp8x4_t(cur_v.fp84, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +template <> +CK_TILE_DEVICE void atomic_add(bf8x4_t* p_dst, const bf8x4_t& x) +{ + union U32BF84_ADDR + { + uint32_t* u32_a; + bf8x4_t* bf84_a; + }; + + union U32BF84 + { + uint32_t u32; + bf8x4_t bf84; + }; + + U32BF84_ADDR dword_addr; + U32BF84 cur_v; + U32BF84 new_; + uint32_t old_v, new_v; + + dword_addr.bf84_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.bf84 = add_bf8x4_t(cur_v.bf84, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +// +// Atomic add for fp8x8_t +// +template <> +CK_TILE_DEVICE void atomic_add(fp8x8_t* p_dst, fp8x8_t const& x) +{ + // Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer. + union U64FP88_ADDR + { + uint64_t* u64_a; // pointer to 64-bit integer + fp8x8_t* fp88_a; // pointer to fp8x8_t + }; + + union U64FP88 + { + uint64_t u64; + fp8x8_t fp88; + }; + + U64FP88_ADDR dword_addr; + U64FP88 cur_v; + U64FP88 new_v_union; + uint64_t old_v, new_v; + + // Point to the destination as both fp8x8_t* and uint64_t*. + dword_addr.fp88_a = p_dst; + // Initial read of 64 bits from memory + cur_v.u64 = *dword_addr.u64_a; + + do + { + old_v = cur_v.u64; + // Add each fp8 element using your add_fp8x8_t(...) routine + new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x); + new_v = new_v_union.u64; + + // Attempt 64-bit CAS + cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v); + } while(cur_v.u64 != old_v); +} + +// +// Atomic add for bf8x8_t +// +template <> +CK_TILE_DEVICE void atomic_add(bf8x8_t* p_dst, bf8x8_t const& x) +{ + union U64BF88_ADDR + { + uint64_t* u64_a; + bf8x8_t* bf88_a; + }; + + union U64BF88 + { + uint64_t u64; + bf8x8_t bf88; + }; + + U64BF88_ADDR dword_addr; + U64BF88 cur_v; + U64BF88 new_v_union; + uint64_t old_v, new_v; + + dword_addr.bf88_a = p_dst; + // Read the original 64 bits + cur_v.u64 = *dword_addr.u64_a; + + do + { + old_v = cur_v.u64; + // Add each bf8 element using your add_bf8x8_t(...) routine + new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x); + new_v = new_v_union.u64; + + // 64-bit CAS loop + cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v); + } while(cur_v.u64 != old_v); +} + template CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) { @@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) (std::is_same::value && (N == 1)) || (std::is_same::value && (N == 1 || N == 2)) || (std::is_same::value && (N == 1 || N == 2)) || - (std::is_same::value && (N == 2 || N == 4)), - "wrong! not implemented"); + (std::is_same::value && (N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 4 || N == 8 || N == 16)), + "The granularity of the thread buffer is unsupported on the hardware!"); constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; @@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) } else if constexpr(N == 4) { - atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); - atomic_add(c_style_pointer_cast(p_dst) + 1, - x.template get_as()[I1]); + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + else if constexpr(N == 8) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, + x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 8) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 16) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 8) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + } + if constexpr(N == 16) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); } } } diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index bb5d8bfa86..39a904717c 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,6 +20,7 @@ #include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" +#include "ck_tile/host/reference/reference_batched_transpose.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" @@ -34,4 +35,3 @@ #include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" -#include "ck_tile/host/reference/reference_batched_transpose.hpp" diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 2eff11dd25..ea70563d58 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -22,13 +22,14 @@ template ::value, + static_assert(is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; @@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; @@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; @@ -74,13 +75,14 @@ template ::value, + static_assert(is_any_of::value, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); @@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; @@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; @@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } if(!res) { - std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index fc412e8831..da0de457d4 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A, int b_index = (std::is_same_v) ? col * strideB + k : k * strideB + col; - acc += static_cast(A[a_index]) * static_cast(B[b_index]); + acc += ck_tile::type_convert(A[a_index]) * + ck_tile::type_convert(B[b_index]); } int c_index = (std::is_same_v) ? row * strideC + col : col * strideC + row; - C[c_index] = acc; + C[c_index] = ck_tile::type_convert(acc); } } diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 8741e0a49b..ade2f18041 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4aba3d7ec1..155dbad6e3 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -77,6 +77,7 @@ struct CShuffleEpilogue * * @return The vector store size for C tensor. */ + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { constexpr index_t MaxVectorStoreSize = 16; @@ -142,7 +143,7 @@ struct CShuffleEpilogue TileDistributionEncodingPattern2D(), tile_distribution_pattern::thread_raked>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 646d380a18..ab21398b99 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr // TODO: Should we have two policies? Interwave & Intrawave ?? static constexpr index_t InterWaveSchedulingMacClusters = 1; - static constexpr index_t KPack = WarpGemm::kKPerThread; + // should be at least equal to: WarpGemm::Impl::kABKPerLane + // and the question is how to assess upper limit or exact value? + // TODO: Should we introduce AK1/BK1 parameters ? + static constexpr index_t KPack = 8; static constexpr index_t KPerThread = KIterPerWarp * KPack; static constexpr index_t KRepeat = KPerThread / KPack; }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 4c65f51914..aa31d1fccf 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -159,7 +159,7 @@ struct GemmKernel CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { - if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && is_any_of::value) { if(kargs.k_batch != 1) @@ -240,7 +240,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + if(kargs.N % EpiloguePipeline::template GetVectorSizeC() != 0) { std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; return false; @@ -255,7 +255,7 @@ struct GemmKernel << std::endl; return false; } - if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + if(kargs.M % EpiloguePipeline::template GetVectorSizeC() != 0) { std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; return false; @@ -321,7 +321,7 @@ struct GemmKernel c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), - number{}, + number()>{}, number<1>{}); } else @@ -519,7 +519,7 @@ struct GemmKernel { // Do not compile in case where we have unsupported // VectorSizeC & data type configuration. - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && is_any_of::value)) { RunGemm( diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 0bd7807238..0a40ca359e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -3,6 +3,9 @@ #pragma once +#include +#include + #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" @@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 return Policy::template GetSmemSize(); } + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + template struct PipelineImpl : public PipelineImplBase { @@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { - constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); - constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); - constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; constexpr index_t WaveSize = 64; constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); - constexpr index_t A_LDS_Read_Width = KPerXDL; - constexpr index_t B_LDS_Read_Width = KPerXDL; + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); - constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); - constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 38c663f4c3..e23f0cda7d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem // LocalPreFillStages: 1 // LocalPreFetchStages: 0 // LocalSharedMemoryBuffer: 1 -template +template struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = BaseGemmPipelineAgBgCrMem; @@ -165,11 +165,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "A/B Dram block window should have the same data type as appropriate " "([A|B]DataType) defined in Problem definition!"); - static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], - "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" - " or KPerBlock!"); + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); // ------------------------------------------------------------------------------------ // Definitions of all needed tiles @@ -213,25 +224,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tuple_array a_block_tiles; tuple_array b_block_tiles; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + // ----------------------------------------------------------------------------------------- // Gemm pipeline start // prefetch // global read 0 - Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); - Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); + Base::GlobalPrefetch( + a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); - Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window, + b_dram_tile_window_step); }); // main body @@ -247,19 +292,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_sync_lds(); - Base::LocalPrefill( - a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); - Base::LocalPrefill( - b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d( + a_shuffle_tmp, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill( + a_copy_lds_window, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d( + b_shuffle_tmp, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill( + b_copy_lds_window, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + b_element_func); + } Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window); + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window); + b_copy_dram_window, + b_dram_tile_window_step); }); i += PrefetchStages; @@ -275,12 +346,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_sync_lds(); - Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); - Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{}), + a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{}), + b_element_func); + } }); block_sync_lds(); @@ -352,11 +443,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "A/B Dram block window should have the same data type as appropriate " "([A|B]DataType) defined in Problem definition!"); - static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], - "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" - " or KPerBlock!"); + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); // ------------------------------------------------------------------------------------ // Definitions of all needed tiles @@ -400,25 +502,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tuple_array a_block_tiles; tuple_array b_block_tiles; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); // ----------------------------------------------------------------------------------------- // Gemm pipeline start // prefetch // global read 0 - Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); - Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); + Base::GlobalPrefetch( + a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); - Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window, + b_dram_tile_window_step); }); // main body @@ -432,19 +567,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // no second block_sync_lds because it's interwave - Base::LocalPrefill( - a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); - Base::LocalPrefill( - b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d( + a_shuffle_tmp, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill( + a_copy_lds_window, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d( + b_shuffle_tmp, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill( + b_copy_lds_window, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + b_element_func); + } Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window); + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window); + b_copy_dram_window, + b_dram_tile_window_step); }); i += PrefetchStages; @@ -457,12 +618,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // no second block_sync_lds because it's interwave - Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); - Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{}), + a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{}), + b_element_func); + } }); block_sync_lds(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 33f105a435..2a9683b36e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; @@ -519,7 +518,7 @@ struct UniversalGemmPipelineAgBgCrPolicy using ALayout = remove_cvref_t; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeA(); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp new file mode 100644 index 0000000000..649f130c41 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMXGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& a_m_kblock_scales, + const Tensor& b_k_n, + const Tensor& b_kblock_n_scales, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + a_m_kblock_scales_{a_m_kblock_scales}, + b_k_n_{b_k_n}, + b_kblock_n_scales_{b_kblock_n_scales}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& a_m_kblock_scales_; + const Tensor& b_k_n_; + const Tensor& b_kblock_n_scales_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMXGemm::Argument; + + float Run(const Argument& arg) + { + using GemmInstance = ck::tensor_operation::host::ReferenceGemm; + + Tensor a_m_k_scaled(arg.a_m_k_.mDesc); + Tensor b_k_n_scaled(arg.b_k_n_.mDesc); + + const auto M = arg.a_m_k_.mDesc.GetLengths()[0]; + const auto N = arg.b_k_n_.mDesc.GetLengths()[1]; + const auto K = arg.a_m_k_.mDesc.GetLengths()[1]; + const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1]; + + for(size_t m = 0; m < M; m++) + { + for(size_t k = 0; k < K; k++) + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + } + + for(size_t n = 0; n < N; n++) + { + for(size_t k = 0; k < K; k++) + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + } + + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled, + b_k_n_scaled, + arg.c_m_n_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + + ref_invoker.Run(ref_argument); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& a_m_kblock_scales, + const Tensor& b_k_n, + const Tensor& b_kblock_n_scales, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, + a_m_kblock_scales, + b_k_n, + b_kblock_n_scales, + c_m_n, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMXGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp index 379a005024..13818b4f95 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp @@ -72,7 +72,7 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple >; template -using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple < +using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| @@ -86,16 +86,34 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // Memory friendly - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 8, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 256, 8, 8, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> - // clang-format on - >; + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 4, 4, 32, 32, 2, 1, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 4, 4, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 2, 2, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> #endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/test/ck_tile/gemm/test_gemm_pipeline.cpp b/test/ck_tile/gemm/test_gemm_pipeline.cpp index faffe848d5..5193f2db20 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -14,28 +14,26 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Intrawave = ck_tile::integral_constant; -// using Interwave = ck_tile::integral_constant; -// using Mem = ck_tile::integral_constant; -using Comp = ck_tile::integral_constant; - -// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors. +using Interwave = ck_tile::integral_constant; +using Mem = ck_tile::integral_constant; +using Comp = ck_tile::integral_constant; // clang-format off using KernelTypes = ::testing::Types< // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType - // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, - // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, - // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, - // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, - // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, - // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, - // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp> - // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> + std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> >; // clang-format on