diff --git a/Jenkinsfile b/Jenkinsfile index 1d23daec25..a35b0e1892 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -351,12 +351,12 @@ def cmake_build(Map conf=[:]){ } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ - archiveArtifacts "perf_tile_gemm_*.log" + archiveArtifacts "perf_tile_gemm_**.log" if (arch_type == 1){ - stash includes: "perf_tile_gemm_**_fp16_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" + stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" } else if (arch_type == 2){ - stash includes: "perf_tile_gemm_**_fp16_gfx942.log", name: "perf_tile_gemm_log_gfx942" + stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942" } } catch(Exception err){ @@ -799,8 +799,8 @@ pipeline { description: "Run the ck_tile FMHA tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", - defaultValue: true, - description: "Run the ck_tile GEMM tests (default: ON)") + defaultValue: false, + description: "Run the ck_tile GEMM tests (default: OFF)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 57298b68dc..69051423fb 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -29,8 +29,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr int kBlockPerCu = 1; // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; @@ -54,7 +54,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } +} + int run_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(a_layout == "R" && b_layout == "C") + if(data_type == "fp16") { - if(data_type == "fp16") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "bf16") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "fp8") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "bf8") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported data_type!"); - } + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } + else if(data_type == "bf16") + { + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "fp8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } +#endif else { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + throw std::runtime_error("Unsupported data type for this operation !!!"); } } diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 988f8319b5..3254a407fd 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -114,7 +114,7 @@ struct GemmTypeConfig }; template <> -struct GemmTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; @@ -123,7 +123,7 @@ struct GemmTypeConfig }; template <> -struct GemmTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::fp8_t; using BDataType = ck_tile::fp8_t; @@ -132,7 +132,7 @@ struct GemmTypeConfig }; template <> -struct GemmTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::bf8_t; using BDataType = ck_tile::bf8_t; diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp16.sh similarity index 100% rename from example/ck_tile/03_gemm/script/benchmark_basic.sh rename to example/ck_tile/03_gemm/script/benchmark_basic_fp16.sh diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp16.sh similarity index 100% rename from example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh rename to example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp16.sh diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/03_gemm/script/run_full_test.sh b/example/ck_tile/03_gemm/script/run_full_test.sh index 45bd1bed61..2448acbad2 100755 --- a/example/ck_tile/03_gemm/script/run_full_test.sh +++ b/example/ck_tile/03_gemm/script/run_full_test.sh @@ -32,14 +32,11 @@ function print_log_header(){ } # run verification tests -example/ck_tile/03_gemm/script/smoke_test_basic.sh example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh # run performance benchmarks -export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log" -print_log_header $gemm_basic_log $env_type $branch $host_name -example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log - -export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log" -print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name -example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log +for dtype in fp16 bf16 fp8 bf8; do + export gemm_log="perf_tile_gemm_mem_pipeline_${dtype}_${GPU_arch}.log" + print_log_header $gemm_log $env_type $branch $host_name + example/ck_tile/03_gemm/script/benchmark_mem_pipeline_$dtype.sh 2>&1 | tee -a $gemm_log +done diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 8c04066b20..eef8d3b60e 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -76,7 +76,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using GemmPipeline = GEMM_PIPELINE; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem{}, ck_tile::integral_constant{}); } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } else { std::ostringstream err; @@ -205,11 +217,29 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& } else { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } } return ave_time; @@ -217,133 +247,113 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& #include "run_gemm_example.inc" +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } +} + int run_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(a_layout == "R" && b_layout == "R") + if(data_type == "fp16") { - if(data_type == "fp16") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(data_type == "bf16") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(data_type == "fp8") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(data_type == "bf8") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported data_type!"); - } + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } - else if(a_layout == "R" && b_layout == "C") + else if(data_type == "bf16") { - if(data_type == "fp16") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "bf16") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "fp8") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else if(data_type == "bf8") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "fp8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - else if(data_type == "pk_int4_t") - { - // TODO: Add support for bhalf_t ADataType - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } -#endif - else - { - throw std::runtime_error("Unsupported data_type!"); - } - } - else if(a_layout == "C" && b_layout == "C") + else if(data_type == "pk_int4_t") { - if(data_type == "fp16") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } - else if(data_type == "bf16") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } - else if(data_type == "fp8") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } - else if(data_type == "bf8") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - else if(data_type == "pk_int4_t") - { - // TODO: Add support for bhalf_t ADataType - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } + // TODO: Add support for bhalf_t ADataType + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } #endif - else - { - throw std::runtime_error("Unsupported data_type!"); - } - } - else if(a_layout == "C" && b_layout == "R") - { - if(data_type == "fp16") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else if(data_type == "bf16") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else if(data_type == "fp8") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else if(data_type == "bf8") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported data_type!"); - } - } else { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + throw std::runtime_error("Unsupported data type for this operation !!!"); } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + try + { + run_gemm_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + // Return a non-zero code to indicate failure + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +} diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index e6fc08c545..07c6aa0baf 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -361,7 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) { if constexpr(N == 2) { - atomic_add(c_style_pointer_cast(p_dst), bit_cast(x)); + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); } else if constexpr(N == 4) { diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index facc3e45ee..a4e8ca6a2b 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -523,7 +523,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) int exponent = (x & 0x7F) >> SrcT_mant; if constexpr(is_fnuz) { - if(x == 0x80) + if((x & 0xff) == 0x80) { return fNaN; } diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 155dbad6e3..0081edcb2e 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -9,7 +9,9 @@ namespace ck_tile { -template struct CShuffleEpilogueProblem { + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using CLayout = remove_cvref_t; @@ -40,9 +44,13 @@ struct CShuffleEpilogueProblem template struct CShuffleEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + using Problem = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BTypeToUse = + std::conditional_t, ODataType, BDataType>; using CLayout = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = Problem::kMPerBlock; @@ -56,8 +64,8 @@ struct CShuffleEpilogue static constexpr index_t kMPerIteration = kMPerXdl * kMWave; static constexpr index_t kNPerIteration = kNPerXdl * kNWave; - using WG = WarpGemmMfmaDispatcher CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { constexpr index_t MaxVectorStoreSize = 16; @@ -143,7 +150,7 @@ struct CShuffleEpilogue TileDistributionEncodingPattern2D(), + GetVectorSizeC(), tile_distribution_pattern::thread_raked>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 972c71e93b..503a92b863 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -167,7 +167,7 @@ struct GemmKernel CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { - if constexpr(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value) { if(kargs.k_batch != 1) @@ -275,7 +275,7 @@ struct GemmKernel } return false; } - if(kargs.N % EpiloguePipeline::template GetVectorSizeC() != 0) + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -295,7 +295,7 @@ struct GemmKernel } return false; } - if(kargs.M % EpiloguePipeline::template GetVectorSizeC() != 0) + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -407,7 +407,7 @@ struct GemmKernel c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), - number()>{}, + number{}, number<1>{}); } else @@ -671,7 +671,7 @@ struct GemmKernel } else { - if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -694,7 +694,7 @@ struct GemmKernel } else { - if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { RunGemm( diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 71d8ef1b3d..c198c9443a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -33,8 +33,21 @@ struct BaseGemmPipelineAgBgCrCompV3 CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { - ignore = num_loop; - return TailNumber::Full; + if(BlockHasHotloop(num_loop)) + { + return TailNumber::Full; + } + else + { + if(num_loop == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } } }; @@ -470,6 +483,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -478,12 +492,43 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } while(i < (num_loop - 1)); } // tail - if constexpr(TailNum == TailNumber::Full) + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } - // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle - // latency // __builtin_amdgcn_sched_barrier(0); return c_block_tile; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index f95d80a6f5..0e0ee9dbd8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -143,7 +143,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / @@ -442,6 +442,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 Base::LocalPrefill( b_copy_lds_window1, b_global_load_tile, b_element_func); } + block_sync_lds(); Base::GlobalPrefetch( a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 41ea89b2bd..2a10389ce6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -26,6 +26,10 @@ struct GemmPipelineAGmemBGmemCRegV1 using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; @@ -81,11 +85,21 @@ struct GemmPipelineAGmemBGmemCRegV1 std::is_same_v>, "wrong!"); - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); + constexpr bool is_a_col_major = std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + static_assert(is_a_col_major + ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); // A tile in LDS ADataType* p_a_lds = static_cast(p_smem); diff --git a/script/process_perf_data.py b/script/process_perf_data.py index 0d56c9baa2..2dd54fa62d 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -344,6 +344,30 @@ def main(): for i in range(1, len(results)+1): testlist.append("Test%i"%i) table_name="ck_tile_gemm_mem_pipeline_fp16_tflops" + if 'gemm_basic_bf16' in filename: + for i in range(1, len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_tile_gemm_basic_bf16_tflops" + if 'gemm_mem_pipeline_bf16' in filename: + for i in range(1, len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_tile_gemm_mem_pipeline_bf16_tflops" + if 'gemm_basic_fp8' in filename: + for i in range(1, len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_tile_gemm_basic_fp8_tflops" + if 'gemm_mem_pipeline_fp8' in filename: + for i in range(1, len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_tile_gemm_mem_pipeline_fp8_tflops" + if 'gemm_basic_bf8' in filename: + for i in range(1, len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_tile_gemm_basic_bf8_tflops" + if 'gemm_mem_pipeline_bf8' in filename: + for i in range(1, len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_tile_gemm_mem_pipeline_bf8_tflops" tflops_base = get_baseline(table_name,conn) store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine) diff --git a/script/process_perf_data.sh b/script/process_perf_data.sh index 815cf41e2d..fc44064874 100755 --- a/script/process_perf_data.sh +++ b/script/process_perf_data.sh @@ -43,19 +43,12 @@ file=./perf_fmha_bwd_gfx90a.log if [ -e "$file" ]; then python3 process_perf_data.py perf_fmha_bwd_gfx90a.log fi -file=./perf_tile_gemm_basic_fp16_gfx942.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx942.log -fi -file=./perf_tile_gemm_basic_fp16_gfx90a.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx90a.log -fi -file=./perf_tile_gemm_mem_pipeline_fp16_gfx942.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx942.log -fi -file=./perf_tile_gemm_mem_pipeline_fp16_gfx90a.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx90a.log -fi + +for gpu in "gfx90a" "gfx942"; do + for dtype in "fp16" "bf16" "fp8" "bf8"; do + file=./perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log + if [ -e "$file" ]; then + python3 process_perf_data.py perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log + fi + done +done diff --git a/script/process_qa_data.sh b/script/process_qa_data.sh index c5bc1b9a1a..420453cddc 100755 --- a/script/process_qa_data.sh +++ b/script/process_qa_data.sh @@ -52,19 +52,12 @@ file=./perf_fmha_bwd_gfx90a.log if [ -e "$file" ]; then python3 process_perf_data.py perf_fmha_bwd_gfx90a.log fi -file=./perf_gemm_basic_gfx942.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_gemm_basic_gfx942.log -fi -file=./perf_gemm_basic_gfx90a.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_gemm_basic_gfx90a.log -fi -file=./perf_gemm_mem_pipeline_gfx942.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log -fi -file=./perf_gemm_mem_pipeline_gfx90a.log -if [ -e "$file" ]; then - python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log -fi + +for gpu in "gfx90a" "gfx942"; do + for dtype in "fp16" "bf16" "fp8" "bf8"; do + file=./perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log + if [ -e "$file" ]; then + python3 process_perf_data.py perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log + fi + done +done diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 5d0929f0e4..0f787b718d 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -67,7 +67,9 @@ class TestCkTileBatchedGemm : public ::testing::Test using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem +class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3 + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesMem); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp new file mode 100644 index 0000000000..1da0028f63 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp @@ -0,0 +1,16 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4 + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesMem); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline.cpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp similarity index 62% rename from test/ck_tile/gemm/test_gemm_pipeline.cpp rename to test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index f0236b5d88..bd1502516b 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -10,6 +10,7 @@ using F16 = ck_tile::half_t; using F32 = float; +using F8 = ck_tile::fp8_t; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Intrawave = ck_tile::integral_constant; // clang-format off -using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType +using KernelTypesMem = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem> +>; + +using KernelTypesCompV3 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> - >; + std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3> +>; + +using KernelTypesCompV4 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4> +>; + // clang-format on - -TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes); - -#include "test_gemm_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp b/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp new file mode 100644 index 0000000000..a7f4e68386 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_mem.cpp @@ -0,0 +1,16 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineMem + +TYPED_TEST_SUITE(TestCkTileGemmPipelineMem, KernelTypesMem); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index e53015a975..1f0683f8b8 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -3,7 +3,10 @@ #pragma once -TYPED_TEST(TestCkTileGemmPipeline, SmallM) +#ifndef TEST_GEMM_PIPELINE_UT_CASES_INC +#define TEST_GEMM_PIPELINE_UT_CASES_INC + +TYPED_TEST(TEST_SUITE_NAME, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; @@ -13,18 +16,25 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM) { if constexpr(std::is_same_v) + { EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } else + { this->Run(M, N, K); + } } } -TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) +TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 1024; constexpr int K = 320; - constexpr int VecLoadSize = 8; + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v) + ? 16 + : 8; for(int M : Ms) { @@ -33,9 +43,13 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) { // TODO: Can we anyhow deduce used vector load size? if(M % VecLoadSize == 0) + { this->Run(M, N, K); + } else + { EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } else { @@ -44,7 +58,7 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) } } -TYPED_TEST(TestCkTileGemmPipeline, PaddK) +TYPED_TEST(TEST_SUITE_NAME, PaddK) { std::vector Ms{128}; constexpr int N = 1024; @@ -54,7 +68,7 @@ TYPED_TEST(TestCkTileGemmPipeline, PaddK) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmPipeline, Regular) +TYPED_TEST(TEST_SUITE_NAME, Regular) { std::vector Ms{512}; constexpr int N = 1024; @@ -64,7 +78,16 @@ TYPED_TEST(TestCkTileGemmPipeline, Regular) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument) +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, NotSupportedArgument) { constexpr int M = 512; constexpr int N = 1025; @@ -76,3 +99,5 @@ TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument) EXPECT_THROW((this->template Run(M, N, K)), std::runtime_error); } + +#endif diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 3a9203a5bf..1b997ddbce 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -11,6 +11,27 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + enum struct GemmPipelineType { Mem, @@ -63,7 +84,7 @@ class TestCkTileGemmPipeline : public ::testing::Test // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -71,8 +92,6 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - // TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong - // values. constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr bool kPadM = PadM; @@ -136,7 +155,9 @@ class TestCkTileGemmPipeline : public ::testing::Test typename GemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem( a_m_k, b_k_n, c_m_n_host_ref); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; EXPECT_TRUE(pass); } }; diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 6b9bf0c6f7..cd94d0b867 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -79,6 +79,8 @@ class TestCkTileGroupedGemm : public ::testing::Test template using GemmEpilogue = ck_tile::CShuffleEpilogue