From bdea637a15baeb0fca2ca213cd577738892858c4 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 23 Sep 2025 22:54:52 -0700 Subject: [PATCH] Fix the gfx950 numerical errors (#2911) * Update grouped_gemm example and pipeline * find the root cause error in did not enable the transpose in gfx950 correctly * Fix v3 pipeline, row and col major * Disable f8 datatype tests, it fails on gfx950 * fix the abd test by clear the runtime argument unsupported --------- Co-authored-by: AviralGoelAMD Co-authored-by: Mateusz Ozga [ROCm/composable_kernel commit: b159841a06eaee568ad8336603b0b00ff38a7314] --- .../quant_run_grouped_gemm_example.inc | 12 +++ .../run_grouped_gemm_example.inc | 16 ++- .../ops/gemm/kernel/gemm_multi_abd_kernel.hpp | 12 +++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 17 ++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 4 +- test/ck_tile/gemm_multi_abd/CMakeLists.txt | 8 +- .../test_gemm_multi_abd_cshuffle.cpp | 8 +- .../test_gemm_multi_abd_default2d.cpp | 10 +- .../test_gemm_multi_abd_ut_cases_cshuffle.inc | 99 ------------------- 9 files changed, 63 insertions(+), 123 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 17e0ee5342..658a4dfa62 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -183,12 +183,24 @@ int run_grouped_gemm_example_with_layouts(int argc, if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { std::cout << "Please check the input data. Default values will be used." << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 128 * i); + // Let get_default_stride calculate based on layout stride_As.push_back(0); stride_Bs.push_back(0); stride_Cs.push_back(0); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 1cd2212994..026f2bd8f6 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -172,15 +172,25 @@ int run_grouped_gemm_example_with_layouts(int argc, std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, " "896, 1280..)" << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 384 * i); - stride_As.push_back(Ks[i]); - stride_Bs.push_back(Ks[i]); - stride_Cs.push_back(Ns[i]); + // Set default strides based on layout later using get_default_stride + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); } } diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp index 3b050e03ed..b4ddc33e8d 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -132,6 +132,10 @@ struct GemmKernelMultiABD static constexpr index_t NumBTensor = BsDataType::size(); static constexpr index_t NumDTensor = DsDataType::size(); + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + using DDataType = remove_cvref_t>; + CK_TILE_HOST static auto GetName() -> const std::string { return UniversalGemmKernel::GetName(); @@ -181,6 +185,14 @@ struct GemmKernelMultiABD { return false; } + // Currently MultiABD kernel doesn't support F8 data type + if(ck_tile::get_device_name() == "gfx950" && + (std::is_same::value || + std::is_same::value || + std::is_same::value)) + { + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 7159eda683..2b0b2e8488 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -530,7 +530,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); @@ -542,7 +543,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -553,7 +554,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -577,7 +578,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -596,7 +598,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -607,7 +609,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -619,7 +621,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c73fa29245..75790afecd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -100,7 +100,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; @@ -118,7 +118,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index ac3b59d5d3..8f9b694a3b 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -5,8 +5,8 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) - add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) - target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) + add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) + target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_ck_tile_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 9821963458..87d6a9101c 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -24,14 +24,16 @@ using KernelTypes = ::testing::Types< std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + + // Currently MultiABD kernel doesn't support F8 data type + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index b3a89aba05..f2476e803f 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -22,17 +22,17 @@ using KernelTypes = ::testing::Types< // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + + // Currently MultiABD kernel doesn't support F8 data type + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc index 5aa113608f..e9a8ed74f2 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -1,104 +1,5 @@ #pragma once -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) {