From 167e5ab3b512c47fac6d2f2d77946ab9be2ce110 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Wed, 24 Sep 2025 06:15:34 +0000 Subject: [PATCH] Merge commit 'dcd33a6ecc30e18cc8491ed03926ab5ac8b6f1c3' into develop --- .../quant_run_grouped_gemm_example.inc | 12 +++ .../run_grouped_gemm_example.inc | 16 ++- .../arch/amd_buffer_addressing_builtins.hpp | 54 ++++++++++ include/ck_tile/core/tensor/load_tile.hpp | 3 - .../ops/epilogue/cshuffle_epilogue.hpp | 20 +--- .../ops/gemm/kernel/gemm_multi_abd_kernel.hpp | 12 +++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 17 ++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 4 +- test/ck_tile/gemm_multi_abd/CMakeLists.txt | 8 +- .../test_gemm_multi_abd_cshuffle.cpp | 8 +- .../test_gemm_multi_abd_default2d.cpp | 10 +- .../test_gemm_multi_abd_ut_cases_cshuffle.inc | 99 ------------------- 12 files changed, 121 insertions(+), 142 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 17e0ee5342..658a4dfa62 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -183,12 +183,24 @@ int run_grouped_gemm_example_with_layouts(int argc, if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { std::cout << "Please check the input data. Default values will be used." << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 128 * i); + // Let get_default_stride calculate based on layout stride_As.push_back(0); stride_Bs.push_back(0); stride_Cs.push_back(0); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 1cd2212994..026f2bd8f6 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -172,15 +172,25 @@ int run_grouped_gemm_example_with_layouts(int argc, std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, " "896, 1280..)" << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 384 * i); - stride_As.push_back(Ks[i]); - stride_Bs.push_back(Ks[i]); - stride_Cs.push_back(Ns[i]); + // Set default strides based on layout later using get_default_stride + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); } } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 5c7ffefc6a..4e0a86119a 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2570,6 +2570,60 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_ #endif } +// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the +// memory to the SGPR registers. +__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v) +{ + return __builtin_amdgcn_readfirstlane(static_cast(v)); +} + +__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v) +{ + return __builtin_amdgcn_readfirstlane(static_cast(v)); +} + +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +__device__ inline int32_t amd_wave_read_first_lane(int32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +template , int> = 0> +__device__ inline auto amd_wave_read_first_lane(const Object& obj) +{ + constexpr size_t ObjectSize = sizeof(Object); + constexpr size_t SGPR_size = 4; + constexpr size_t NumFull = ObjectSize / SGPR_size; + constexpr size_t Tail = ObjectSize % SGPR_size; + + const unsigned char* src = reinterpret_cast(&obj); + alignas(Object) unsigned char dst[ObjectSize]; + + static_for<0, NumFull, 1>{}([&](auto Ic) { + constexpr size_t offset = Ic * SGPR_size; + uint32_t read_src; + __builtin_memcpy(&read_src, src + offset, SGPR_size); + read_src = __builtin_amdgcn_readfirstlane(read_src); + __builtin_memcpy(dst + offset, &read_src, SGPR_size); + }); + + if constexpr(Tail != 0) + { + constexpr size_t offset = NumFull * SGPR_size; + uint32_t tail_loc = 0; + __builtin_memcpy(&tail_loc, src + offset, Tail); + tail_loc = __builtin_amdgcn_readfirstlane(tail_loc); + __builtin_memcpy(dst + offset, &tail_loc, Tail); + } + Object out; + __builtin_memcpy(&out, dst, ObjectSize); + return out; +} + template CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t global_offset, diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index a3620453b4..2e9ab0f5c6 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -158,7 +158,4 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window -concept IsLoadableTile = requires { load_tile(std::declval()); }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 6c815d804d..585a5f5b42 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -481,13 +481,10 @@ struct CShuffleEpilogue auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); auto sn_tile = make_static_distributed_tensor(dram_tile_distribution); - // Build windows only if scales are provided + // Build windows only if non-scalar scales are provided auto scale_m_window = [&]() { if constexpr(has_scales && !has_scalar_scales) { - static_assert( - IsLoadableTile, - "ScaleM must be a loadable tile"); return make_tile_window(scale_m, dram_tile_distribution); } else @@ -498,9 +495,6 @@ struct CShuffleEpilogue auto scale_n_window = [&]() { if constexpr(has_scales && !has_scalar_scales) { - static_assert( - IsLoadableTile, - "ScaleN must be a loadable tile"); return make_tile_window(scale_n, dram_tile_distribution); } else @@ -515,8 +509,8 @@ struct CShuffleEpilogue merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - // If scales provided, load them with identical distribution - if constexpr(has_scales && IsLoadableTile && IsLoadableTile) + // If non-scalar scales provided, load them with identical distribution + if constexpr(has_scales && !has_scalar_scales) { sm_tile = load_tile(scale_m_window); // row scales in permuted layout sn_tile = load_tile(scale_n_window); // col scales in permuted layout @@ -535,7 +529,7 @@ struct CShuffleEpilogue { v = static_cast(v * scale_m * scale_n); } - else if constexpr(has_scales) + else if constexpr(has_scales && !has_scalar_scales) { // same linear index mapping on the permuted distribution const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]); @@ -636,9 +630,6 @@ struct CShuffleEpilogue } else if constexpr(has_scales) { - static_assert( - IsLoadableTile, - "ScaleM must be a loadable tile"); return make_tile_window(scale_m, lds_tile.get_tile_distribution()); } else @@ -653,9 +644,6 @@ struct CShuffleEpilogue } else if constexpr(has_scales) { - static_assert( - IsLoadableTile, - "ScaleN must be a loadable tile"); return make_tile_window(scale_n, lds_tile.get_tile_distribution()); } else diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp index 3b050e03ed..b4ddc33e8d 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -132,6 +132,10 @@ struct GemmKernelMultiABD static constexpr index_t NumBTensor = BsDataType::size(); static constexpr index_t NumDTensor = DsDataType::size(); + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + using DDataType = remove_cvref_t>; + CK_TILE_HOST static auto GetName() -> const std::string { return UniversalGemmKernel::GetName(); @@ -181,6 +185,14 @@ struct GemmKernelMultiABD { return false; } + // Currently MultiABD kernel doesn't support F8 data type + if(ck_tile::get_device_name() == "gfx950" && + (std::is_same::value || + std::is_same::value || + std::is_same::value)) + { + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 7159eda683..2b0b2e8488 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -530,7 +530,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); @@ -542,7 +543,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -553,7 +554,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -577,7 +578,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -596,7 +598,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -607,7 +609,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -619,7 +621,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c73fa29245..75790afecd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -100,7 +100,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; @@ -118,7 +118,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index ac3b59d5d3..8f9b694a3b 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -5,8 +5,8 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) - add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) - target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) + add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) + target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_ck_tile_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 9821963458..87d6a9101c 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -24,14 +24,16 @@ using KernelTypes = ::testing::Types< std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + + // Currently MultiABD kernel doesn't support F8 data type + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index b3a89aba05..f2476e803f 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -22,17 +22,17 @@ using KernelTypes = ::testing::Types< // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + + // Currently MultiABD kernel doesn't support F8 data type + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc index 5aa113608f..e9a8ed74f2 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -1,104 +1,5 @@ #pragma once -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) {